The Variable Class
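The Variable class is defined in PyTorch's torch.autograd module. The following code demonstrates how to use it. (Since PyTorch 0.4, Variable has been merged into Tensor. Variable(tensor) simply returns an ordinary tensor, which is why the output below is printed as tensor(..., requires_grad=True).)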
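import torch
from torch.autograd import Variable

tensor = torch.FloatTensor([[1, 2], [3, 4]])
# Wrap the tensor so that autograd records the operations applied to it
var = Variable(tensor, requires_grad=True)
print('original matrix:', tensor)
print('\nVariable:', var)

# The same computation on the plain tensor and on the Variable
t_out = torch.mean(tensor*tensor)
v_out = torch.mean(var*var)
print('\norigin squared and mean:', t_out)
print('\nVariable squared and mean:', v_out)

# Backpropagate from v_out; the gradient with respect to var is stored in var.grad
v_out.backward()
print('\ngradient after backward', var.grad)

# var.data is the underlying tensor, which can be converted to a NumPy array
print('\nvar.data is a tensor:', var.data)
print('\nthe tensor can become a numpy:', var.data.numpy())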
Output:
original matrix: tensor([[1., 2.],
        [3., 4.]])

Variable: tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

origin squared and mean: tensor(7.5000)

Variable squared and mean: tensor(7.5000, grad_fn=<MeanBackward0>)

gradient after backward tensor([[0.5000, 1.0000],
        [1.5000, 2.0000]])

var.data is a tensor: tensor([[1., 2.],
        [3., 4.]])

the tensor can become a numpy: [[1. 2.]
 [3. 4.]]
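Because Variable is deprecated in current PyTorch, the same demo can be written without it. The sketch below assumes PyTorch 0.4 or later and replaces Variable(tensor, requires_grad=True) with torch.tensor(..., requires_grad=True); it is one possible modern equivalent, not the original author's code.

```python
import torch

# A leaf tensor that tracks gradients directly (no Variable wrapper needed)
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)

# Same computation as above: mean of the element-wise square
out = torch.mean(x * x)
print(out)

# Backpropagate; the gradient of mean(x * x) over 4 elements is x / 2
out.backward()
print(x.grad)

# detach() gives a tensor without gradient tracking, which can become a NumPy array
arr = x.detach().numpy()
print(arr)
```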
A Variable supports the same operations as a tensor.
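To check that a Variable behaves like a tensor, the sketch below applies a few operations to both and compares the results. It assumes PyTorch 0.4 or later, where Variable(tensor) returns a torch.Tensor; the particular operations (matrix product, sum, exp) are an arbitrary choice for illustration.

```python
import torch
from torch.autograd import Variable

t = torch.FloatTensor([[1, 2], [3, 4]])
v = Variable(t, requires_grad=True)

# The wrapped object is an ordinary tensor
print(isinstance(v, torch.Tensor))

# Identical operations on the plain tensor and on the Variable
results_t = [t @ t, t.sum(), torch.exp(t)]
results_v = [v @ v, v.sum(), torch.exp(v)]
for rt, rv in zip(results_t, results_v):
    # detach() drops the autograd history before comparing values
    print(torch.equal(rt, rv.detach()))
```

The only visible difference is that results computed from v carry a grad_fn, because autograd records how they were produced.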
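In addition, it records the operations applied to it, so backward() can compute gradients, which are then stored in its grad attribute.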
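The gradient printed above can be checked by hand: for out = mean(x²) over n elements, the partial derivative with respect to each x_i is 2·x_i / n, so with n = 4 the gradient is x / 2. The short sketch below compares autograd's result with this closed form; the name expected is illustrative only.

```python
import torch
from torch.autograd import Variable

x = Variable(torch.FloatTensor([[1, 2], [3, 4]]), requires_grad=True)
out = torch.mean(x * x)
out.backward()

# Closed-form gradient of mean(x^2): 2 * x / number_of_elements
expected = 2 * x.data / x.numel()
print(x.grad)
print(torch.allclose(x.grad, expected))  # True
```

Note that grad accumulates across repeated backward() calls, so training loops usually reset it (for example with x.grad.zero_()) before each new backward pass.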
Its data attribute is the underlying tensor, which does not track gradients.
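The sketch below contrasts .data with detach(), which is generally preferred in current PyTorch because in-place changes made through .data are hidden from autograd. It also shows why the original example converts var.data rather than var to NumPy: calling numpy() on a tensor that requires grad raises a RuntimeError. The variable names are illustrative, and the sketch assumes PyTorch 0.4 or later on CPU.

```python
import torch
from torch.autograd import Variable

var = Variable(torch.FloatTensor([[1, 2], [3, 4]]), requires_grad=True)

d1 = var.data      # plain tensor sharing storage with var, no gradient tracking
d2 = var.detach()  # same idea; the recommended form in newer code
print(d1.requires_grad, d2.requires_grad)  # False False

# Converting the gradient-tracking tensor directly fails
try:
    var.numpy()
    numpy_failed = False
except RuntimeError as err:
    numpy_failed = True
    print("RuntimeError:", err)

# The NumPy array shares memory with var, so an in-place change is visible in both
arr = var.detach().numpy()
d2[0, 0] = 10.0
print(var[0, 0].item(), arr[0, 0])  # 10.0 10.0
```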