The Variable class

The Variable class is defined in PyTorch's torch.autograd module. The following code demonstrates how to use it:

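import torch
from torch.autograd import Variable

# A 2x2 float tensor, wrapped in a Variable that tracks gradients
tensor = torch.FloatTensor([[1, 2], [3, 4]])
var = Variable(tensor, requires_grad=True)
print('original matrix:', tensor)
print('\nVariable:', var)

# Same computation on both: the mean of the element-wise squares
t_out = torch.mean(tensor*tensor)
v_out = torch.mean(var*var)
print('\norigin squared and mean:', t_out)
print('\nVariable squared and mean:', v_out)

# Backpropagate: d(v_out)/d(var) = 2 * var / 4 = var / 2
v_out.backward()
print('\ngradient after backward', var.grad)

# .data is the underlying tensor, without gradient tracking
print('\nvar.data is a tensor:', var.data)

# A tensor that does not require grad can be converted to a NumPy array
print('\nthe tensor can become a numpy:', var.data.numpy())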

Output:

original matrix: tensor([[1., 2.],
        [3., 4.]])

Variable: tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

origin squared and mean: tensor(7.5000)

Variable squared and mean: tensor(7.5000, grad_fn=<MeanBackward0>)

gradient after backward tensor([[0.5000, 1.0000],
        [1.5000, 2.0000]])

var.data is a tensor: tensor([[1., 2.],
        [3., 4.]])

the tensor can become a numpy: [[1. 2.]
 [3. 4.]]
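In PyTorch 0.4 and later, Variable has been merged into Tensor: Variable(tensor, requires_grad=True) returns an ordinary torch.Tensor, which is why the output above prints the Variable as tensor(..., requires_grad=True). The sketch below, assuming such a version is installed, checks this and shows the same computation written without Variable; the names base and x are only illustrative.

```python
import torch
from torch.autograd import Variable

# In PyTorch >= 0.4, Variable(...) hands back a regular torch.Tensor
base = torch.FloatTensor([[1, 2], [3, 4]])
var = Variable(base, requires_grad=True)
print(isinstance(var, torch.Tensor))  # True
print(var.requires_grad)              # True

# Equivalent modern code: create the tensor with requires_grad directly
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
out = torch.mean(x * x)
out.backward()
print(x.grad)  # same gradient as var.grad in the demo
```

For new code, torch.tensor(..., requires_grad=True) is the usual choice; the Variable wrapper is kept mainly for backward compatibility.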

A Variable supports the same operations as a tensor.
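To illustrate that a gradient-tracking tensor (what Variable produces in current PyTorch) supports the same operations as a plain one, the sketch below applies a few operations to both and compares the results. It assumes PyTorch 0.4 or later; plain, tracked, and ops are illustrative names. The values agree, and the only difference is that results computed from tracked carry a grad_fn.

```python
import torch

plain = torch.tensor([[1., 2.], [3., 4.]])
# Same values, with gradient tracking switched on
tracked = plain.clone().requires_grad_(True)

# Element-wise, matrix and reduction operations
ops = {
    "add": lambda t: t + 1,
    "mul": lambda t: t * t,
    "matmul": lambda t: t @ t,
    "sum": lambda t: t.sum(),
}

for name, op in ops.items():
    a = op(plain)
    b = op(tracked)
    # detach() drops the graph so the values can be compared directly;
    # b.grad_fn records the operation that produced b
    print(name, torch.equal(a, b.detach()), b.grad_fn is not None)
```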

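In addition, it records the operations applied to it, so calling backward() computes the gradients and stores them in its grad attribute.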
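The gradient printed by the demo can be checked by hand: with out = (1/n) Σ x_i² and n = 4, each partial derivative is 2·x_i/n = x_i/2, giving [[0.5, 1.0], [1.5, 2.0]]. The sketch below, assuming PyTorch 0.4 or later and using torch.tensor(..., requires_grad=True) in place of Variable, verifies this and also shows that repeated backward() calls add into grad, so it is normally zeroed between steps.

```python
import torch

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)

out = torch.mean(x * x)  # out = (1/n) * sum(x_i^2), n = 4
print(out.grad_fn)       # the recorded operation, e.g. <MeanBackward0 ...>

out.backward()
# Analytic gradient: 2 * x_i / n
expected = 2 * x.detach() / x.numel()
print(torch.allclose(x.grad, expected))      # True

# A second backward() accumulates into x.grad instead of overwriting it
torch.mean(x * x).backward()
print(torch.allclose(x.grad, 2 * expected))  # True

# Reset the accumulated gradient in place
x.grad.zero_()
```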

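Its data attribute holds the underlying tensor, with gradient tracking disabled, which is why var.data.numpy() can convert it to a NumPy array.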
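Calling numpy() directly on a tensor that requires grad raises a RuntimeError, which is why the demo goes through var.data. The sketch below (assuming PyTorch 0.4 or later; numpy_failed and arr are illustrative names) shows this, along with detach(), which the PyTorch documentation recommends over .data: both return a tensor that shares storage and does not require grad, but in-place changes made through a detach()-ed tensor are still caught by autograd's checks, whereas changes made through .data are not.

```python
import torch

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)

# Both .data and detach() give a tensor that does not require grad
print(x.data.requires_grad)      # False
print(x.detach().requires_grad)  # False

# numpy() refuses a tensor that requires grad
numpy_failed = False
try:
    x.numpy()
except RuntimeError as err:
    numpy_failed = True
    print("x.numpy() failed:", err)

# Detach first; for a CPU tensor the array shares memory with x
arr = x.detach().numpy()
print(arr)
```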

posted on 2021-09-02 17:36  菜小疯