13-反向传播法求梯度

反向传播法求梯度

利用计算图求梯度是一种比较方便又快速的方法,如何利用计算图求梯度?先回忆一下计算图:

z = x 2 + y 2 z=x^2+y^2 z=x2+y2 为例:

img

  • 计算图以箭头和节点构成,正向传播时,得到的结果是 z = x 2 + y 2 z=x^2+y^2 z=x2+y2

  • 反向传播时,得到的结果是: ∂ L ∂ z × 2 x \frac{\partial L}{\partial z}\times 2x zL×2x ∂ L ∂ z × 2 y \frac{\partial L}{\partial z}\times 2y zL×2y

  • 仔细一看,令 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 zL=1 不就得到梯度为 ( 2 x , 2 y ) (2x,2y) (2x,2y) 了吗!(注意这里的x, y是正向传播的x和y)
    为什么令 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 zL=1 就可以得到梯度了呢?实际上在计算图中,L才是输出函数,而z只是中间变量,但在这里,z也是输出函数,所以 L=z, 因此有 ∂ L ∂ z = 1 \frac{\partial L}{\partial z}=1 zL=1

  • 认识到这点,现在就可以写代码实现求梯度了,这里用一个类来实现:

class SqrtWithAdd:
    def __init__(self):
        self.x = None
        self.y = None
        
    def forward(self, x, y):
        self.x = x
        self.y = y
        out = x**2 + y**2
        return out
    
    def backward(self, dout=1): 
        dx = dout * 2*self.x
        dy = dout * 2*self.y
        return dx,dy           #(dx, dy)就是梯度
        

现在利用这个类来求一下在点 (2, 3) 处的梯度:

sqrt_with_add = SqrtWithAdd()    #实例化类
sqrt_with_add.forward(2,3)       #先进行正向传播
grad = sqrt_with_add.backward()  #求梯度

print(grad)

#输出:
(4, 6)

经手算验证,结果正确。

梯度下降求最小值

既然用反向传播的方法求出了梯度值,那么现在就想用这个方法结合梯度下降法来求一下函数最小值。

先把公式写一下:
x = x − η ∂ f ∂ x y = y − η ∂ f ∂ y x = x - \eta \frac{\partial f}{\partial x}\\y = y - \eta \frac{\partial f}{\partial y} x=xηxfy=yηyf
之前讲过了, η \eta η 是学习率。

代码如下:

def gradient_descent(init_x,init_y, lr=0.01, step_num=100):
    x = init_x
    y = init_y
    sqrt_with_add = SqrtWithAdd()    #创建实例
    for i in range(step_num):
        sqrt_with_add.forward(x,y)   #正向传播
        dx,dy = sqrt_with_add.backward()     #反向传播求梯度
        x -= lr * dx
        y -= lr * dy
    return x,y

#调用函数求最小值位置
x,y = gradient_descent(init_x= 3,init_y= 4) #设置开始起点(3,4)

print('[{}, {}]'.format(x,y))
        

输出:

[0.39785866768425965, 0.5304782235790126]

结果可以通过调整学习率 lr 和 学习次数 step_num 来接近最佳位置。

这里举的例子是二维的,可以通过调整上面的类的输入值的维数来使其变为可以求 n维的。


最后总结一下反向传播法求梯度与数值微分法求梯度的区别

反向传播求梯度实际上用的是解析法来求导,跟自己手算求梯度是一样的,记得求导公式就可以;

而数值微分法求的梯度实际上用的是导数定义法来求

posted @ 2020-08-02 00:11  aJream  阅读(92)  评论(0编辑  收藏  举报