Forward and Reverse Automatic Differentiation of Composite Functions
About
- First published: 2024-09-13
- References:
  - https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation
  - Calculus: Early Transcendentals, 9e, James Stewart (2020)
  - https://en.wikipedia.org/wiki/Automatic_differentiation
- My understanding is limited; if you spot any errors, please point them out.
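Forward and Reverse Automatic Differentiation: The Math
Let us first review the differentiation rules from calculus.
Review of Differentiation Rules
Product Rule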
\[f(x) = u(x) \times v(x)
\]
\[\begin{aligned}
\frac{dy}{dx} &= \frac{du}{dx} \times v + \frac{dv}{dx} \times u \\
f'(x) &= u'v + v'u
\end{aligned}
\]
\[\begin{aligned}
f(x)&=(3 x-5) \times(4 x+7) \\
u&=3 x-5 \quad v=4 x+7 \\
u^{\prime}&=3 \quad v^{\prime}=4 \\
f^{\prime}(x)&=3(4 x+7)+4(3 x-5) \\
&=12 x+21+12 x-20 \\
&=24 x+1
\end{aligned}
\]
Quotient Rule
\[f(x) = \frac{u(x)}{v(x)}
\]
\[\begin{aligned}
f'(x) &= \frac{u'v - v'u}{v^2} \\
\frac{dy}{dx} &= \frac{\frac{du}{dx}v - \frac{dv}{dx}u}{v^2}
\end{aligned}
\]
\[\begin{aligned}
f(x)&=\frac{3 x-5}{4 x+7} \\
u&=3 x-5 \quad v=4 x+7 \\
u^{\prime}&=3 \quad v^{\prime}=4 \\
f^{\prime}(x)&=\frac{3(4 x+7)-4(3 x-5)}{(4 x+7)^2} \\
&=\frac{12 x+21-12 x+20}{(4 x+7)^2} \\
&=\frac{41}{(4 x+7)^2}
\end{aligned}
\]
Derivatives of sin and cos
\[\begin{aligned}
y &= \sin(x) \\
\frac{dy}{dx} &= \cos(x)
\end{aligned}
\]
\[\begin{aligned}
y &= \cos(x) \\
\frac{dy}{dx} &= -\sin(x)
\end{aligned}
\]
Chain Rule (Single-Variable Composite Functions)
\[y = f(u) \quad u = g(x)
\]
\[\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}
\]
\[\begin{aligned}
y&=(2 x+4)^3 \\
y&=u^3 \text { and } u=2 x+4 \\
\frac{d y}{d u}&=3 u^2 \quad \frac{d u}{d x}=2 \\
\frac{d y}{d x}&=3 u^2 \times 2=2 \times 3(2 x+4)^2 \\
&=6(2 x+4)^2
\end{aligned}
\]
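The three worked examples above (product, quotient, and chain rule) can be sanity-checked numerically. The sketch below compares each hand-derived derivative with a central finite difference; the helper name `central_difference`, the test point `x0`, and the step size are illustrative choices, not part of the original derivation.

```python
def central_difference(func, x, h=1e-6):
    """Estimate func'(x) with a symmetric finite difference."""
    return (func(x + h) - func(x - h)) / (2 * h)

# Each entry: (function, derivative obtained by hand in the text)
examples = {
    "product":  (lambda x: (3 * x - 5) * (4 * x + 7), lambda x: 24 * x + 1),
    "quotient": (lambda x: (3 * x - 5) / (4 * x + 7), lambda x: 41 / (4 * x + 7) ** 2),
    "chain":    (lambda x: (2 * x + 4) ** 3,          lambda x: 6 * (2 * x + 4) ** 2),
}

x0 = 1.5  # arbitrary test point; 4x + 7 is nonzero here
for name, (f, f_prime) in examples.items():
    analytic = f_prime(x0)
    numeric = central_difference(f, x0)
    print(f"{name:9s} analytic={analytic:.6f} numeric={numeric:.6f}")
```

The two columns should agree to several decimal places; a mismatch would point to an algebra slip in the hand derivation.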
Multivariable Chain Rule (Case 1)
\[\begin{aligned}
z &= f(x,y) \\
x &= g(t) \\
y &= h(t) \\
\end{aligned}
\]
\[\frac{d z}{d t}=\frac{\partial f}{\partial x} \frac{d x}{d t}+\frac{\partial f}{\partial y} \frac{d y}{d t}
\]
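As a worked illustration of Case 1, consider the following example. The particular functions are chosen here only for illustration and do not come from the text above.

\[\begin{aligned}
z &= x^2 y + 3 x y^4 \quad x = \sin(2t) \quad y = \cos(t) \\
\frac{\partial z}{\partial x} &= 2 x y + 3 y^4 \quad \frac{\partial z}{\partial y} = x^2 + 12 x y^3 \\
\frac{d x}{d t} &= 2\cos(2t) \quad \frac{d y}{d t} = -\sin(t) \\
\frac{d z}{d t} &= (2 x y + 3 y^4) \cdot 2\cos(2t) - (x^2 + 12 x y^3) \cdot \sin(t)
\end{aligned}
\]

At \(t = 0\) we have \(x = 0\) and \(y = 1\), so \(\frac{dz}{dt} = (0 + 3) \cdot 2 - (0 + 0) \cdot 0 = 6\).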
Multivariable Chain Rule (Case 2)
\[\begin{aligned}
z &= f(x,y) \\
x & = g(s,t) \\
y &= h(s,t)
\end{aligned}
\]
\[\frac{\partial z}{\partial s}=\frac{\partial z}{\partial x} \frac{\partial x}{\partial s}+\frac{\partial z}{\partial y} \frac{\partial y}{\partial s} \quad \frac{\partial z}{\partial t}=\frac{\partial z}{\partial x} \frac{\partial x}{\partial t}+\frac{\partial z}{\partial y} \frac{\partial y}{\partial t}
\]
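When computing \(\frac{\partial z}{\partial s}\), we hold \(t\) fixed and compute the ordinary derivative of \(z\) with respect to \(s\), which is exactly the multivariable chain rule (Case 1). The same applies when computing \(\frac{\partial z}{\partial t}\).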
Multivariable Chain Rule (General Version)
\[\begin{aligned}
u &= f(x_1, x_2, \ldots, x_n) \\
x_k &= g_k(t_1, t_2, \ldots, t_m) \qquad \text{for } 1 \leq k \leq n
\end{aligned}
\]
\[\begin{aligned}
&\frac{\partial u}{\partial t_i}=\frac{\partial u}{\partial x_1} \frac{\partial x_1}{\partial t_i}+\frac{\partial u}{\partial x_2} \frac{\partial x_2}{\partial t_i}+\cdots+\frac{\partial u}{\partial x_n} \frac{\partial x_n}{\partial t_i}
\end{aligned}
\qquad \text{for } 1 \leq i \leq m
\]
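The general chain rule is just a sum of products of local derivatives, which is easy to express in code. The following is a minimal sketch; the function name `chain_rule` and the concrete example (\(u = x_1 x_2\), \(x_1 = t_1 + t_2\), \(x_2 = t_1 t_2\)) are hypothetical and chosen only to make the formula concrete.

```python
def chain_rule(du_dx, dx_dt):
    """Combine local derivatives with the general chain rule.

    du_dx[k]    = partial u / partial x_k     (length n)
    dx_dt[k][i] = partial x_k / partial t_i   (n rows, m columns)
    Returns a list whose i-th entry is partial u / partial t_i.
    """
    n = len(du_dx)
    m = len(dx_dt[0])
    return [sum(du_dx[k] * dx_dt[k][i] for k in range(n)) for i in range(m)]

# Hypothetical example: u = x1 * x2, x1 = t1 + t2, x2 = t1 * t2
t1, t2 = 2.0, 3.0
x1, x2 = t1 + t2, t1 * t2
du_dx = [x2, x1]        # [du/dx1, du/dx2]
dx_dt = [[1.0, 1.0],    # [dx1/dt1, dx1/dt2]
         [t2, t1]]      # [dx2/dt1, dx2/dt2]
print(chain_rule(du_dx, dx_dt))  # [21.0, 16.0]
```

Expanding \(u = t_1^2 t_2 + t_1 t_2^2\) by hand gives \(\partial u/\partial t_1 = 2 t_1 t_2 + t_2^2 = 21\) and \(\partial u/\partial t_2 = t_1^2 + 2 t_1 t_2 = 16\), matching the output.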
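Composite Functions, Partial Derivatives, the Chain Rule, and Forward and Reverse Automatic Differentiation
Evaluation Order in Forward and Reverse Mode
For the composite function: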
\[\begin{aligned}
y & =f(g(h(x)))=f\left(g\left(h\left(w_0\right)\right)\right)=f\left(g\left(w_1\right)\right)=f\left(w_2\right)=w_3 \\
w_0 & =x \\
w_1 & =h\left(w_0\right) \\
w_2 & =g\left(w_1\right) \\
w_3 & =f\left(w_2\right)=y
\end{aligned}
\]
The chain rule gives:
\[\begin{aligned}
\frac{\partial y}{\partial x}&=\frac{\partial y}{\partial w_2} \frac{\partial w_2}{\partial w_1} \frac{\partial w_1}{\partial x}=\frac{\partial f\left(w_2\right)}{\partial w_2} \frac{\partial g\left(w_1\right)}{\partial w_1} \frac{\partial h\left(w_0\right)}{\partial x}
\end{aligned}
\]
Evaluation order:
- Forward differentiation first computes \(\partial w_1 / \partial x\), then \(\partial w_2/\partial w_1\), and finally \(\partial y / \partial w_2\)
- Reverse differentiation first computes \(\partial y / \partial w_2\), then \(\partial w_2/\partial w_1\), and finally \(\partial w_1 / \partial x\)
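The two evaluation orders can be sketched on a concrete chain. The functions below (\(h = \sin\), \(g(w_1) = w_1^2\), \(f = \exp\)) are illustrative assumptions; the point is only that both orders multiply the same three local derivatives, grouped differently.

```python
import math

# Illustrative choices (not from the original text):
# h(x) = sin(x), g(w1) = w1**2, f(w2) = exp(w2)
x = 0.5
w1 = math.sin(x)
w2 = w1 ** 2
y = math.exp(w2)

# Local derivative of each step
dw1_dx = math.cos(x)    # dh/dx
dw2_dw1 = 2 * w1        # dg/dw1
dy_dw2 = math.exp(w2)   # df/dw2

# Forward mode: start at the input and carry dw_k/dx along
dw2_dx = dw2_dw1 * dw1_dx
forward = dy_dw2 * dw2_dx

# Reverse mode: start at the output and carry dy/dw_k along
dy_dw1 = dy_dw2 * dw2_dw1
reverse = dy_dw1 * dw1_dx

print(forward, reverse)  # same derivative up to rounding
```

For a single-input, single-output chain the two orders cost the same; the difference only matters once there are several inputs or several outputs.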
Forward Differentiation
For the composite function:
\[\begin{aligned}
r &= ? \\
s &= ? \\
t &= ? \\
x &= g(r,s,t) \\
y & = h(r,s,t) \\
z &= i(r,s,t) \\
u &= f(x,y,z)
\end{aligned}
\]
Forward differentiation computes:
\[\begin{aligned}
\frac{\partial r}{\partial v} &= ? \\
\frac{\partial s}{\partial v} &= ? \\
\frac{\partial t}{\partial v} &= ? \\
\\
\frac{\partial x}{\partial v} &= \frac{\partial x}{\partial r}\frac{\partial r}{\partial v} + \frac{\partial x}{\partial s}\frac{\partial s}{\partial v} + \frac{\partial x}{\partial t}\frac{\partial t}{\partial v} \\
\frac{\partial y}{\partial v} &= \frac{\partial y}{\partial r}\frac{\partial r}{\partial v} + \frac{\partial y}{\partial s}\frac{\partial s}{\partial v} + \frac{\partial y}{\partial t}\frac{\partial t}{\partial v} \\
\frac{\partial z}{\partial v} &= \frac{\partial z}{\partial r}\frac{\partial r}{\partial v} + \frac{\partial z}{\partial s}\frac{\partial s}{\partial v} + \frac{\partial z}{\partial t}\frac{\partial t}{\partial v} \\
\\
\frac{\partial u}{\partial v}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial v}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial v}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial v}
\end{aligned}
\]
When \(v=r\), i.e. treating \(r\) as the independent variable while holding \(s\) and \(t\) fixed, we get
\[\begin{aligned}
\frac{\partial r}{\partial v} &= 1 \\
\frac{\partial s}{\partial v} &= 0 \\
\frac{\partial t}{\partial v} &= 0 \\
\frac{\partial u}{\partial r}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial r}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial r}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial r}
\end{aligned}
\]
When \(v=s\), i.e. treating \(s\) as the independent variable while holding \(r\) and \(t\) fixed, we get
\[\begin{aligned}
\frac{\partial r}{\partial v} &= 0 \\
\frac{\partial s}{\partial v} &= 1 \\
\frac{\partial t}{\partial v} &= 0 \\
\frac{\partial u}{\partial s}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial s}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial s}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial s}
\end{aligned}
\]
When \(v=t\), i.e. treating \(t\) as the independent variable while holding \(s\) and \(r\) fixed, we get
\[\begin{aligned}
\frac{\partial r}{\partial v} &= 0 \\
\frac{\partial s}{\partial v} &= 0 \\
\frac{\partial t}{\partial v} &= 1 \\
\frac{\partial u}{\partial t}&=\frac{\partial u}{\partial x} \frac{\partial x}{\partial t}+\frac{\partial u}{\partial y} \frac{\partial y}{\partial t}+\frac{\partial u}{\partial z} \frac{\partial z}{\partial t}
\end{aligned}
\]
Reverse Differentiation
For the composite function:
\[\begin{aligned}
u_1 &= p(x_1, x_2) \\
u_2 &= q(x_1, x_2) \\
y_1 &= f(u_1, u_2) \\
y_2 &= g(u_1, u_2) \\
y_3 &= h(u_1, u_2)
\end{aligned}
\]
Reverse differentiation computes:
\[\begin{aligned}
\frac{\partial s}{\partial y_1} &= ? \\
\frac{\partial s}{\partial y_2} &= ? \\
\frac{\partial s}{\partial y_3} &= ? \\
\\
\frac{\partial s}{\partial u_1} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_1} + \frac{\partial s}{\partial y_2}\frac{\partial y_2}{\partial u_1} + \frac{\partial s}{\partial y_3}\frac{\partial y_3}{\partial u_1} \\
\frac{\partial s}{\partial u_2} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_2} + \frac{\partial s}{\partial y_2}\frac{\partial y_2}{\partial u_2} + \frac{\partial s}{\partial y_3}\frac{\partial y_3}{\partial u_2} \\
\\
\frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_1} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_1} \\
\frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_2} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_2}
\end{aligned}
\]
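One can think of \(s\) as some function of the outputs, \(s=\text{function}(y_1,y_2,y_3)\).
When \(s=y_1\), i.e. treating \(y_1\) as the independent variable while holding \(y_2\) and \(y_3\) fixed, we get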
\[\begin{aligned}
\frac{\partial s}{\partial y_1} &= 1 \\
\frac{\partial s}{\partial y_2} &= 0 \\
\frac{\partial s}{\partial y_3} &= 0 \\
\\
\frac{\partial s}{\partial u_1} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_1}\\
\frac{\partial s}{\partial u_2} &= \frac{\partial s}{\partial y_1}\frac{\partial y_1}{\partial u_2} \\
\\
\frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_1} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_1} \\
\frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial u_1}\frac{\partial u_1}{\partial x_2} + \frac{\partial s}{\partial u_2}\frac{\partial u_2}{\partial x_2}
\end{aligned}
\]
Computing Automatic Differentiation by Example
The Example
Suppose there are 2 input variables (\(x_1\), \(x_2\)) and 2 output variables (\(y_1\), \(y_2\)):
\[\begin{aligned}
m_1 &= x_1 \cdot x_2 + \sin(x_1) \\
m_2 &= 4x_1 + 2x_2 + \cos(x_2) \\
y_1 &= m_1 + m_2 \\
y_2 &= m_1 \cdot m_2
\end{aligned}
\tag{1}
\]
That is:
\[\begin{aligned}
y_1 &= x_1 \cdot x_2 + \sin(x_1) + 4x_1 + 2x_2 + \cos(x_2) \\
y_2 &= (x_1 \cdot x_2 + \sin(x_1)) \cdot (4x_1 + 2x_2 + \cos(x_2))
\end{aligned}
\]
The partial derivatives are:
\[\begin{aligned}
\frac{\partial y_1}{\partial x_1} &= x_2 + \cos(x_1) + 4 \\
\frac{\partial y_1}{\partial x_2} &= x_1 + 2 - \sin(x_2) \\
\frac{\partial y_2}{\partial x_1} &= (x_2 + \cos(x_1)) \cdot m_2 + m_1 \cdot 4 \\
\frac{\partial y_2}{\partial x_2} &= x_1 \cdot m_2 + m_1 \cdot (2 - \sin(x_2))
\end{aligned}
\]
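These closed-form partial derivatives serve as the reference that any automatic differentiation result should reproduce. As a rough sanity check, the sketch below compares them with central finite differences at an arbitrary point; the function names and the test point are illustrative, and \(\partial y_2/\partial x_2\) is obtained from the product rule in the same way as \(\partial y_2/\partial x_1\).

```python
import math

def outputs(x1, x2):
    """Evaluate y1 and y2 from formula (1)."""
    m1 = x1 * x2 + math.sin(x1)
    m2 = 4 * x1 + 2 * x2 + math.cos(x2)
    return m1 + m2, m1 * m2

def closed_form(x1, x2):
    """Partial derivatives written out by hand (sum and product rules)."""
    m1 = x1 * x2 + math.sin(x1)
    m2 = 4 * x1 + 2 * x2 + math.cos(x2)
    dy1_dx1 = x2 + math.cos(x1) + 4
    dy1_dx2 = x1 + 2 - math.sin(x2)
    dy2_dx1 = (x2 + math.cos(x1)) * m2 + m1 * 4
    dy2_dx2 = x1 * m2 + m1 * (2 - math.sin(x2))
    return (dy1_dx1, dy1_dx2), (dy2_dx1, dy2_dx2)

def finite_difference(x1, x2, h=1e-6):
    """Central-difference estimate of the same four partial derivatives."""
    plus_x1, minus_x1 = outputs(x1 + h, x2), outputs(x1 - h, x2)
    plus_x2, minus_x2 = outputs(x1, x2 + h), outputs(x1, x2 - h)
    d_dx1 = [(p - m) / (2 * h) for p, m in zip(plus_x1, minus_x1)]
    d_dx2 = [(p - m) / (2 * h) for p, m in zip(plus_x2, minus_x2)]
    return (d_dx1[0], d_dx2[0]), (d_dx1[1], d_dx2[1])

x1, x2 = 1.0, 2.0  # arbitrary test point
print(closed_form(x1, x2))
print(finite_difference(x1, x2))
```

Finite differences are only approximate and need two function evaluations per input, which is one reason automatic differentiation is preferred in practice.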
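Next, we use this example to show how forward and reverse automatic differentiation are carried out.
Forward Automatic Differentiation
We will use the following chain rule: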
\[\begin{aligned}
\frac{\partial w}{\partial t}
&= \sum_i \left(\frac{\partial w}{\partial u_i} \cdot \frac{\partial u_i}{\partial t}\right) \\
&= \frac{\partial w}{\partial u_1} \cdot \frac{\partial u_1}{\partial t} + \frac{\partial w}{\partial u_2} \cdot \frac{\partial u_2}{\partial t} + \cdots
\end{aligned}
\]
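where:
- \(w\) denotes the output
  - in the example, \(y_1\) or \(y_2\)
- \(u_i\) denotes the input variables that directly affect \(w\)
  - in the example, \(m_1\) and \(m_2\) when \(w\) is \(y_1\) or \(y_2\) (and likewise \(a\) and \(b\) when the rule is applied to the intermediate variable \(w = m_1\))
- \(t\) denotes the input variable that is yet to be chosen
  - in the example, one of \(x_1\) or \(x_2\)
Before computing, we first decompose formula (1) into simple operations: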
\[\begin{aligned}
x_1 &= ? \\
x_2 &= ? \\
\\
a &= x_1 \cdot x_2 \\
b &= \sin(x_1) \\
\\
c &= 4x_1 + 2x_2 \\
d &= \cos(x_2) \\
\\
m_1 &= a + b \\
m_2 &= c + d \\
\\
y_1 &= m_1 + m_2 \\
y_2 &= m_1 \cdot m_2
\end{aligned}
\tag{2}
\]
Now we differentiate with respect to the yet-to-be-chosen variable \(t\):
\[\begin{aligned}
\frac{\partial x_1}{\partial t} &= ? \\
\frac{\partial x_2}{\partial t} &= ? \\
\\
\frac{\partial a}{\partial t} &= x_2\frac{\partial x_1}{\partial t} + x_1 \frac{\partial x_2}{\partial t} \\
\frac{\partial b}{\partial t} &= \cos(x_1) \frac{\partial x_1}{\partial t} \\
\\
\frac{\partial c}{\partial t} &= 4\frac{\partial x_1}{\partial t} + 2 \frac{\partial x_2}{\partial t} \\
\frac{\partial d}{\partial t} &= -\sin(x_2)\frac{\partial x_2}{\partial t} \\
\\
\frac{\partial m_1}{\partial t} &= \frac{\partial a}{\partial t} + \frac{\partial b}{\partial t} \\
\frac{\partial m_2}{\partial t} &= \frac{\partial c}{\partial t} + \frac{\partial d}{\partial t} \\
\\
\frac{\partial y_1}{\partial t} &= \frac{\partial m_1}{\partial t} + \frac{\partial m_2}{\partial t} \\
\frac{\partial y_2}{\partial t} &= \frac{\partial m_1}{\partial t} \cdot m_2 + \frac{\partial m_2}{\partial t} \cdot m_1
\end{aligned}
\]
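We said earlier that \(t\) was yet to be chosen; now it is time to choose it:
- Substituting \(t=x_1\) into the formulas above gives \(\frac{\partial x_1}{\partial t} = 1\) and \(\frac{\partial x_2}{\partial t}=0\), from which we can compute \(\frac{\partial y_1}{\partial x_1}\) and \(\frac{\partial y_2}{\partial x_1}\)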
\[\begin{aligned}
\frac{\partial x_1}{\partial t} &= 1 \\
\frac{\partial x_2}{\partial t} &= 0 \\
\\
\frac{\partial a}{\partial t} &= x_2\frac{\partial x_1}{\partial t} + x_1 \frac{\partial x_2}{\partial t} = x_2 \\
\frac{\partial b}{\partial t} &= \cos(x_1) \frac{\partial x_1}{\partial t} = \cos(x_1) \\
\\
\frac{\partial c}{\partial t} &= 4\frac{\partial x_1}{\partial t} + 2 \frac{\partial x_2}{\partial t} = 4 \\
\frac{\partial d}{\partial t} &= -\sin(x_2)\frac{\partial x_2}{\partial t} = 0\\
\\
\frac{\partial m_1}{\partial t} &= \frac{\partial a}{\partial t} + \frac{\partial b}{\partial t} = x_2 + \cos(x_1) \\
\frac{\partial m_2}{\partial t} &= \frac{\partial c}{\partial t} + \frac{\partial d}{\partial t} = 4 \\
\\
\frac{\partial y_1}{\partial t} &= \frac{\partial m_1}{\partial t} + \frac{\partial m_2}{\partial t} = x_2 + \cos(x_1) + 4 \\
\frac{\partial y_2}{\partial t} &= \frac{\partial m_1}{\partial t} \cdot m_2 + \frac{\partial m_2}{\partial t} \cdot m_1 = (x_2 + \cos(x_1)) \cdot m_2 + 4 \cdot m_1
\end{aligned}
\]
- Substituting \(t=x_2\) into the formulas above gives \(\frac{\partial x_1}{\partial t} = 0\) and \(\frac{\partial x_2}{\partial t}=1\), from which we can compute \(\frac{\partial y_1}{\partial x_2}\) and \(\frac{\partial y_2}{\partial x_2}\)
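The forward-mode equations above translate almost line by line into code. The sketch below is a direct transcription under the assumption that each variable is paired with its derivative with respect to \(t\); the function name `forward_pass` and the `d`-prefixed variable names are illustrative.

```python
import math

def forward_pass(x1, x2, dx1, dx2):
    """Evaluate program (2) together with derivatives w.r.t. t.

    (dx1, dx2) is the seed: (1, 0) selects t = x1, (0, 1) selects t = x2.
    """
    a, da = x1 * x2, x2 * dx1 + x1 * dx2
    b, db = math.sin(x1), math.cos(x1) * dx1
    c, dc = 4 * x1 + 2 * x2, 4 * dx1 + 2 * dx2
    d, dd = math.cos(x2), -math.sin(x2) * dx2
    m1, dm1 = a + b, da + db
    m2, dm2 = c + d, dc + dd
    y1, dy1 = m1 + m2, dm1 + dm2
    y2, dy2 = m1 * m2, dm1 * m2 + dm2 * m1
    return (y1, y2), (dy1, dy2)

x1, x2 = 1.0, 2.0  # arbitrary evaluation point
_, wrt_x1 = forward_pass(x1, x2, 1.0, 0.0)  # (dy1/dx1, dy2/dx1)
_, wrt_x2 = forward_pass(x1, x2, 0.0, 1.0)  # (dy1/dx2, dy2/dx2)
print(wrt_x1, wrt_x2)
```

Each call yields the derivatives of all outputs with respect to one input, so the full set of partial derivatives takes one call per input.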
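From this we can infer:
- With \(n\) input variables (2 in this example), the formulas above must be evaluated \(n\) times, once per input.
- Suppose the input to a neural network is a 1280 x 720 image and the output is 51 floating-point numbers. Treating each pixel as one input value, forward differentiation would need 1280 x 720 = 921600 evaluations.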
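One common way to make forward mode "automatic" is operator overloading on dual numbers, where every value carries its derivative with respect to the chosen \(t\). The following is a minimal sketch that supports only the operations the example needs (`+`, `*`, `sin`, `cos`); the class name `Dual` and the helpers `program` and `jacobian` are hypothetical names, not an existing library API.

```python
import math

class Dual:
    """A value paired with its derivative with respect to the chosen t."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    @staticmethod
    def lift(other):
        # Plain numbers are constants, so their derivative is 0
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)
    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.value * other.value,
                    self.deriv * other.value + other.deriv * self.value)
    __rmul__ = __mul__

def sin(x):
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)

def cos(x):
    return Dual(math.cos(x.value), -math.sin(x.value) * x.deriv)

def program(x1, x2):
    """Formula (1), written once and reused for every seed."""
    m1 = x1 * x2 + sin(x1)
    m2 = 4 * x1 + 2 * x2 + cos(x2)
    return m1 + m2, m1 * m2

def jacobian(x1, x2):
    """One forward pass per input; result[j][i] = dy_(i+1) / dx_(j+1)."""
    result = []
    for seed1, seed2 in ((1.0, 0.0), (0.0, 1.0)):
        y1, y2 = program(Dual(x1, seed1), Dual(x2, seed2))
        result.append((y1.deriv, y2.deriv))
    return result

print(jacobian(1.0, 2.0))
```

The loop in `jacobian` makes the cost visible: the program runs once per input variable, which is what makes forward mode expensive when there are many inputs.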
Reverse Automatic Differentiation
We will use the following chain rule:
\[\begin{aligned}
\frac{\partial s}{\partial u}
&= \sum_i \left(\frac{\partial w_i}{\partial u} \cdot \frac{\partial s}{\partial w_i}\right) \\
&= \frac{\partial w_1}{\partial u} \cdot \frac{\partial s}{\partial w_1} + \frac{\partial w_2}{\partial u} \cdot \frac{\partial s}{\partial w_2} + \cdots
\end{aligned}
\]
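where:
- \(u\) denotes an input variable
- \(w_i\) denotes the output variables that depend on \(u\)
- \(s\) denotes the variable that is yet to be chosen
Recall the decomposed simple operations (2):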
\[\begin{aligned}
x_1 &= ? \\
x_2 &= ? \\
\\
a &= x_1 \cdot x_2 \\
b &= \sin(x_1) \\
\\
c &= 4x_1 + 2x_2 \\
d &= \cos(x_2) \\
\\
m_1 &= a + b \\
m_2 &= c + d \\
\\
y_1 &= m_1 + m_2 \\
y_2 &= m_1 \cdot m_2
\end{aligned}
\tag{2}
\]
Now compute the reverse differentiation:
\[\begin{aligned}
\frac{\partial s}{\partial y_1} &= ? \\
\frac{\partial s}{\partial y_2} &= ? \\
\\
\frac{\partial s}{\partial m_1} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_1} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_1} \\
\frac{\partial s}{\partial m_2} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_2} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_2} \\
\\
\frac{\partial s}{\partial a} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial a} \\
\frac{\partial s}{\partial b} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial b} \\
\frac{\partial s}{\partial c} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial c} \\
\frac{\partial s}{\partial d} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial d} \\
\\
\frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_1} + \frac{\partial s}{\partial b}\frac{\partial b}{\partial x_1} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_1} \\
\frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_2} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_2} + \frac{\partial s}{\partial d}\frac{\partial d}{\partial x_2}
\end{aligned}
\]
When \(s=y_1\):
\[\begin{aligned}
\frac{\partial s}{\partial y_1} &= 1 \\
\frac{\partial s}{\partial y_2} &= 0 \\
\\
\frac{\partial s}{\partial m_1} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_1} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_1} = 1 \\
\frac{\partial s}{\partial m_2} &= \frac{\partial s}{\partial y_1} \frac{\partial y_1}{\partial m_2} + \frac{\partial s}{\partial y_2} \frac{\partial y_2}{\partial m_2} = 1 \\
\\
\frac{\partial s}{\partial a} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial a} = 1 \\
\frac{\partial s}{\partial b} &= \frac{\partial s}{\partial m_1}\frac{\partial m_1}{\partial b} = 1 \\
\frac{\partial s}{\partial c} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial c} = 1 \\
\frac{\partial s}{\partial d} &= \frac{\partial s}{\partial m_2}\frac{\partial m_2}{\partial d} = 1 \\
\\
\frac{\partial s}{\partial x_1} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_1} + \frac{\partial s}{\partial b}\frac{\partial b}{\partial x_1} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_1} = 1 \cdot x_2 + 1 \cdot \cos(x_1) + 1 \cdot 4 = x_2 + \cos(x_1) + 4 \\
\frac{\partial s}{\partial x_2} &= \frac{\partial s}{\partial a}\frac{\partial a}{\partial x_2} + \frac{\partial s}{\partial c}\frac{\partial c}{\partial x_2} + \frac{\partial s}{\partial d}\frac{\partial d}{\partial x_2} = 1 \cdot x_1 + 1 \cdot 2 + 1 \cdot (-\sin(x_2)) = x_1 + 2 -\sin(x_2)
\end{aligned}
\]
The case \(s=y_2\) is computed in the same way.
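The reverse-mode equations can also be transcribed directly. The sketch below first runs the program forward to obtain the intermediate values, then propagates \(\partial s/\partial(\cdot)\) from the outputs back to the inputs; the function name `reverse_pass` and the `s_`-prefixed variable names (standing for \(\partial s/\partial\) of that variable) are illustrative.

```python
import math

def reverse_pass(x1, x2, s_y1, s_y2):
    """Return (ds/dx1, ds/dx2) for the seed (ds/dy1, ds/dy2).

    The seed (1, 0) selects s = y1 and (0, 1) selects s = y2.
    """
    # Forward sweep: intermediate values needed by the local derivatives
    a = x1 * x2
    b = math.sin(x1)
    c = 4 * x1 + 2 * x2
    d = math.cos(x2)
    m1 = a + b
    m2 = c + d

    # Backward sweep: from the outputs towards the inputs
    s_m1 = s_y1 * 1 + s_y2 * m2   # dy1/dm1 = 1, dy2/dm1 = m2
    s_m2 = s_y1 * 1 + s_y2 * m1   # dy1/dm2 = 1, dy2/dm2 = m1
    s_a = s_m1                    # dm1/da = 1
    s_b = s_m1                    # dm1/db = 1
    s_c = s_m2                    # dm2/dc = 1
    s_d = s_m2                    # dm2/dd = 1
    s_x1 = s_a * x2 + s_b * math.cos(x1) + s_c * 4
    s_x2 = s_a * x1 + s_c * 2 + s_d * (-math.sin(x2))
    return s_x1, s_x2

x1, x2 = 1.0, 2.0  # arbitrary evaluation point
print(reverse_pass(x1, x2, 1.0, 0.0))  # (dy1/dx1, dy1/dx2)
print(reverse_pass(x1, x2, 0.0, 1.0))  # (dy2/dx1, dy2/dx2)
```

Each call yields the derivatives of one output with respect to all inputs, which is the mirror image of the forward-mode case.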
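From this we can infer:
- With \(n\) output variables (2 in this example), the formulas above must be evaluated \(n\) times, once per output.
- Suppose the input to a neural network is a 1280 x 720 image and the output is 51 floating-point numbers. Reverse differentiation would then need only 51 evaluations.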
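Reverse mode can likewise be automated by recording, for every intermediate value, which inputs it was computed from and the corresponding local derivatives, and then walking that record backwards. The following is a simplified sketch of the idea, assuming a small graph of `Var` nodes; the names `Var` and `backward` are hypothetical, only `+`, `*`, `sin` and `cos` are supported, and real frameworks are considerably more elaborate.

```python
import math

class Var:
    """A value that remembers how it was computed."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples of (input Var, d(self)/d(input))
        self.grad = 0.0         # ds/d(self), filled in by backward()

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))
    __rmul__ = __mul__

def sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))

def cos(x):
    return Var(math.cos(x.value), ((x, -math.sin(x.value)),))

def backward(output):
    """Set node.grad = d(output)/d(node) for every node feeding output."""
    order, seen = [], set()

    def visit(node):
        # Depth-first search: inputs are appended before their consumers
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    for node in order:
        node.grad = 0.0
    output.grad = 1.0  # the seed: s is the chosen output itself
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad

x1, x2 = Var(1.0), Var(2.0)
m1 = x1 * x2 + sin(x1)
m2 = 4 * x1 + 2 * x2 + cos(x2)
y1, y2 = m1 + m2, m1 * m2

backward(y1)
grad_y1 = (x1.grad, x2.grad)  # (dy1/dx1, dy1/dx2)
backward(y2)
grad_y2 = (x1.grad, x2.grad)  # (dy2/dx1, dy2/dx2)
print(grad_y1, grad_y2)
```

The forward computation is recorded once, and `backward` is called once per output, so the cost scales with the number of outputs rather than the number of inputs.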