反向传播算法
0 梯度更新函式
1 梯度
神经网络参数如下:
θ = {w1, w2, ... , b1, b2, ...}
权重梯度如下:
为了更好且有效的计算梯度,我们使用反向传播算法。
2 链式法则
3 反向传播
损失函数(Loss function)是定义在单个训练样本上的,比如我们想要分类,就是预测的类别和实际类别的区别,通常用L表示。
总体损失函数(Total loss function)是定义在整个训练集上面的,也就是所有样本的误差的总和。也就是平时我们反向传播需要最小化的值。
对于L(θ)就是所有ln的损失之和,所以如果要算每个L(θ)的偏微分,我们只要算每个l^n的偏微分,再把所有l^n偏微分的结果加起来就是L(θ)的偏微分。
4 实例
计算梯度分为两个步骤:
- 计算∂z/∂w(Forward pass的部分)
- 计算∂l/∂z ( Backward pass的部分 )
5 Forward pass
6 Backward pass
7 总结
我们的目标是要求计算Forward pass的部分和计算Backward pass的部分,然后把∂z/∂w和∂l/∂z相乘,我们就可以得到∂l/∂w,所有我们就可以得到神经网络中所有的参数,然后用梯度下降就可以不断更新,得到损失最小的函数。
8 简单代码实现
1 # -*- coding: utf-8 -*-
2 # @Time : 2022/11/28 16:16
3
4 import numpy as np
5
6 # 输入数据
7 x = np.array([[1, 1], [-1, 1], [-1, -0.5], [-1, 1.5]])
8 t = np.array([0, 1, 0, 1])
9 # print(x.shape[0])
10 w_init = np.array([1, 1])
11 b_init = 0
12 lr = 0.9
13 epoch = 0
14
15
16 # 前向传播
17 def forward(x, w, b):
18 val = np.matmul(x, w.T) + b
19 y_pre = np.zeros_like(val)
20 for i in range(x.shape[0]):
21 if val[i] > 0:
22 y_pre[i] = 1
23 else:
24 y_pre[i] = 0
25 return y_pre, val
26
27
28 # 计算loss损失
29 def loss_fuc(y_pre, w, b):
30 val = np.matmul(x, w.T) + b
31 loss = -(np.matmul(t - y_pre, val) / 4)
32 return loss
33
34
35 # parameter update
36 def grad_update(y_pre, w_old, b_old, lr):
37 grad_w = -(np.matmul(t-y_pre, x)) / 4.0
38 one = np.ones_like(t-y_pre).T
39 grad_b = -(np.matmul(t-y_pre, one) / 4.0)
40 w_new = w_old - lr * grad_w
41 b_new = b_old - lr * grad_b
42
43 return w_new, b_new
44
45
46 if __name__ == '__main__':
47
48 # 前向传播
49 y_pre, val = forward(x, w_init, b_init)
50
51 # 计算loss
52 loss_result = loss_fuc(y_pre, w_init, b_init)
53
54 # 打印输出结果
55 print("epoch:", epoch, "\n", "val:", val, "\n", "w_new:", w_init, "\n", "b_new:", b_init, "\n", "y_pre:", y_pre, "\n", "loss:", loss_result)
56
57 # 反向传播
58 while loss_result != 0:
59 # epoch递增
60 epoch += 1
61
62 # w,b参数更新
63 w_init, b_init = grad_update(y_pre, w_init, b_init, lr)
64
65 # 前向传播
66 y_pre, val = forward(x, w_init, b_init)
67
68 # 计算loss
69 loss_result = loss_fuc(y_pre, w_init, b_init)
70
71 # 打印输出结果
72 print("epoch:", epoch, "\n", "val:", val, "\n", "w_new:", w_init, "\n", "b_new:", b_init, "\n", "y_pre:", y_pre, "\n", "loss:", loss_result)
代码结果:
作者:kali
-------------------------------------------
个性签名:纸上学来终觉浅,绝知此事要躬行。
如果觉得这篇文章对你有小小的帮助的话,记得在右下角点个“推荐”哦,博主在此感谢!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 从HTTP原因短语缺失研究HTTP/2和HTTP/3的设计差异
· 三行代码完成国际化适配,妙~啊~