lesson-6-Autograd

计算图

假设有这样一个算式:y=(x+w)(w+1),则具体细化为:

  • a=x+w
  • b=w+1
  • y=ab

得到计算图如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)    # retain_grad()
y = torch.mul(a, b)

# 反向传播计算所有节点梯度
y.backward()

# 查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
# 查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)
# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)
is_leaf:
 True True False False False
gradient:
 tensor([5.]) tensor([2.]) None None None
grad_fn:
 None None <AddBackward0 object at 0x7fc556fa75e0> <AddBackward0 object at 0x7fc5575c0820> <MulBackward0 object at 0x7fc5570fb340>
  1. 判断是否是叶节点?

    这里w,x是叶节点,y是根节点

  2. 计算梯度

    对于wy关于w的导数计算如下:

    yw=yaaw+ybbw=(w+1)1+(x+w)1=x+2w+1=2+21+1=5

    类似地,y关于x的导数计算如下:

    yx=yaax+ybbx=(w+1)1+(x+w)0=w+1=2

    与输出结果相符。

  3. 计算函数

    对于中间结果节点,会保存计算方式。wx叶节点没有,而aby含有。

  • 补充知识点1:非叶子结点在梯度反向传播结束后释放

    只有叶子节点的梯度得到保留,中间变量的梯度默认不保留;在pytorch中,非叶子结点的梯度在反向传播结束之后就会被释放掉,如果需要保留的话可以对该结点设置retain_grad()

  • 补充知识点2:grad_fn是用来记录创建张量时所用到的运算,在链式求导法则中会使用到

自动求导

目标函数调用backward()方法即可实现反向传播,自动求导,并在tensor上保留梯度。

y.backward(retain_graph=True)
print(w.grad)
y.backward()
print(w.grad)
tensor([5.])
tensor([10.])

知识点1

梯度不会自动清零。多次调用backward(),梯度会累加。

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

for i in range(4):
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    y.backward()   
    print(w.grad)  # 梯度不会自动清零,数据会累加, 通常需要采用 optimizer.zero_grad() 完成对参数的梯度清零

#     w.grad.zero_()

tensor([5.])
tensor([10.])
tensor([15.])
tensor([20.])

手动设置清零

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

for i in range(4):
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    y.backward()   
    print(w.grad)
    # 在计算完梯度之后,清零梯度
    w.grad.zero_()  
tensor([5.])
tensor([5.])
tensor([5.])
tensor([5.])

知识点2

依赖于叶子结点的结点,requires_grad默认为True

w = torch.tensor([1.], requires_grad=True) 
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

print(a.requires_grad, b.requires_grad, y.requires_grad)
print(a.is_leaf, b.is_leaf, y.is_leaf)

True True True
False False False

知识点3

叶子张量不可以执行in-place操作。

叶子结点不可执行in-place,因为计算图的backward过程都依赖于叶子结点的计算,所有的偏微分计算所需要用到的数据都是基于w和x(叶子结点),因此叶子结点不允许in-place操作。

即不能直接修改叶子结点。

a = torch.ones((1, ))
print(id(a), a)

a = a + torch.ones((1, ))
print(id(a), a)

a += torch.ones((1, ))
print(id(a), a)
140485538301904 tensor([1.])
140485538303264 tensor([2.])
140485538303264 tensor([3.])
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

w.add_(1)

y.backward()

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

Cell In[21], line 8
      5 b = torch.add(w, 1)
      6 y = torch.mul(a, b)
----> 8 w.add_(1)
     10 y.backward()


RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

知识点4

detach 的作用

通过以上知识,我们知道计算图中的张量是不能随便修改的,否则会造成计算图的backward计算错误,那有没有其他方法能修改呢?当然有,那就是detach()

detach的作用是:从计算图中剥离出“数据”,并以一个新张量的形式返回,并且新张量与旧张量共享数据,简单的可理解为做了一个别名。 请看下例的w,detach后对w_detach修改数据,w同步地被改为了999

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()

w_detach = w.detach()
w_detach.data[0] = 999
# w同步地被改为999
print(w)
tensor([999.], requires_grad=True)

知识点5

autograd自动构建计算图过程中会保存一系列中间变量,以便于backward的计算,这就必然需要花费额外的内存和时间。

而并不是所有情况下都需要backward,例如推理的时候,因此可以采用上下文管理器——torch.no_grad()来管理上下文,让pytorch不记录相应的变量,以加快速度和节省空间。

简单来说,推理时用with torch.no_grad():包裹计算过程。

posted @   superzzh  阅读(3)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~
点击右上角即可分享
微信分享提示