Tensor.item()
官方解释:张量自带函数,将张量转变为python的数值,它只能用于单个张量,如何用于多变量则为tolist()。
loss=torch.nn.MSELoss()
print(loss(torch.tensor(1.0),torch.tensor(1.0)))
print(loss(torch.tensor(1.0),torch.tensor(1.0)).item())
输出:
tensor(0.)
0.0
由此可知loss返回的是一维张量(MSE损失函数得出的结果是多损失数的平均值),通过item函数从张量变成了python数值。