LR梯度下降法MSE演练
同步进行一波网上代码搬砖, 先来个入门的线性回归模型训练, 基于梯度下降法来, 优化用 MSE 来做. 理论部分就不讲了, 网上一大堆, 我自己也是理解好多年了, 什么 偏导数, 梯度(多远函数一阶偏导数组成的向量) , 方向导数, 反方向(梯度下降) 这些基本的高数知识, 假设大家是非常清楚原理的.
如不清楚原理, 那就没有办法了, 只能自己补, 毕竟 ML 这块, 如果不清楚其数学原理, 只会有框架和导包, 那得是多门的无聊和无趣呀. 这里是搬运代码的, 当然, 我肯定是有改动的,基于我的经验, 做个小笔记, 方便自己后面遇到时, 直接抄呀.
01 采样数据
这里呢, 假设已知一个线性模型, 就假设已经基本训练好了一个, 比如是这样的.
现在为了更好模拟真实样本的观测误差, 给模型添加一个误差变量 (读作 \epsilon) , 然后想要搞成这样的.
现在来随机采样 100次, 得到 n=100 的样本训练数据集
import numpy as np
def data_sample(times=100):
"""数据集采样times次, 返回一个二维数组"""
for i in range(times):
# 随机采样输入 x, 一个数值 (均匀分布)
x = np.random.uniform(-10, 10)
# 采样高斯噪声(随机误差),正态分布
epsilon = np.random.normal(0, 0.01)
# 得到模型输出
y = 1.447 * x + 0.089 + epsilon
# 用生成器来生成或存储样本点
yield [x, y]
# test
# 将数据转为 np.array 的二维数组
data = np.array(list(data_sample()))
data 是这样的, 2D数组, 一共100行记录, 每一行表示一个样本点 (x, y).
array([[ 5.25161007, 7.6922458 ],
[ 9.00034456, 13.11119931],
[ 9.47485633, 13.80426132],
[ -4.3644416 , -6.2183884 ],
[ -3.35345323, -4.76625711],
[ -5.10494006, -7.30976062],
.....
[ -6.78834597, -9.73362456]]
02 计算误差 MSE
计算每个点 (xi, yi) 处的预测值 与 真实值 之差的平方 并 累加, 从而得到整个训练集上的均方误差损失值.
# y = w * x + b
def get_mse(w, b, points):
"""计算训练集的 MSE"""
# points 是一个二维数组, 每行表示一个样本
# 每个样本, 第一个数是 x, 第二个数是 y
loss = 0
for i in range(0,len(X)):
x = points[i, 0]
y = points[i, 1]
# 计算每个点的误差平方, 并进行累加
loss += (y - (w * x + b)) ** 2
# 用 总损失 / 总样本数 = 均方误差 mse
return loss / len(points)
样本是一个二维数组, 或者矩阵. 每一行, 表示一个样本, 每一列表示该样本的某个子特征
03 计算梯度
关于梯度, 即多元函数的偏导数向量, 这个方向是, 多元函数的最大导数方向 (变化率最大) 方向 (向量方向), 于是, 反方向, 则是函数变化率最小, 即极值点的地方呀, 就咱需要的, 所以称为, 梯度下降法嘛, 从数学上就非常好理解哦.
def step_gradient(b_current, w_current, points, lr):
# 计算误差函数在所有点的导数, 并更新 w, b
b_gradient = 0
w_gradinet = 0
n = len(points) # 样本数
for i in range(n):
# x, y 都是一个数值
x = points[i, 0]
y = points[i, 1]
# 损失函数对 b 的导数 g_b = 2/n * (wx+b-y) 数学推导的
b_gradient += (n/2) * ((w_current * x + b) - y)
# 损失函数对 w 的导数 g_w = 2/n (wx+b-y) x
w_gradinet += (n/2) * x * ((w_current * x + b) - y)
# 根据梯度下降法, 更新 w, b
new_w = w_current - (lr * b_gradient)
new_b = b_current - (lr * b_gradient)
return [new_w, new_b]
04 更新梯度 Epoch
根据第三步, 在算出误差函数在 w, b 的梯度后, 就可以通过 梯度下降法来更新 w,b 的值. 我们把对数据集的所有样本训练一次称为一个 Epoch, 共循环迭代 num_iterations 个 Epoch.
def gradient_descent(points, w, b, lr, max_iter):
"""梯度下降 Epoch"""
for step in range(max_iter):
# 计算梯度并更新一次
w, b = step_gradient(b, w, np.array(points),lr)
# 计算当前的 均方差 mse
loss = get_mes(w, b, points)
if step % 50 == 0:
# 每隔50次打印一次当前信息
print(f"iteration: {step} loss: {loss}, w:{w}, b:{b}")
# 返回最后一次的 w,b
return [w, b]
05 主函数
def main():
# 加载训练数据, 即通过真实模型添加高斯噪声得到的
lr = 0.01 # 学习率
init_b = 0
init_w = 0
max_iter = 500 # 最大Epoch=100次
# 用梯度下降法进行训练
w, b = gradient_descent(data, init_w, init_b, lr, max_iter)
# 计算出最优的均方差 loss
loss = get_mse(w, b, dataa)
print(f"Final loss: {loss}, w:{w}, b:{b}")
# 运行主函数
main()
iteration: 0 loss: 52624.8637745707, w:-37.451784525811654, b:-37.451784525811654
iteration: 50 loss: 8.751081967754209e+134, w:-5.0141110054193505e+66, b:-5.0141110054193505e+66
iteration: 100 loss: 1.7286223665339186e+265, w:-7.047143783692584e+131, b:-7.047143783692584e+131
iteration: 150 loss: inf, w:-9.904494626138306e+196, b:-9.904494626138306e+196
iteration: 200 loss: inf, w:-1.3920393397706614e+262, b:-1.3920393397706614e+262
iteration: 250 loss: nan, w:nan, b:nan
iteration: 300 loss: nan, w:nan, b:nan
iteration: 350 loss: nan, w:nan, b:nan
iteration: 400 loss: nan, w:nan, b:nan
iteration: 450 loss: nan, w:nan, b:nan
************************************************************
Final loss: nan, w:nan, b:nan
可以看到, 在 Epoch 100多次, 后, 就已经收敛了. 当然正常来说, 应该给 loss 设置一个阈值的, 不然后面都 inf 了, 我还在 epoch, 这就有问题了. 这里就不改了, 总是习惯留下一些不完美, 这样才会记得更深. 其目的也是在与数理 ML 整个训练过程, 用入门级的 线性回归和 梯度下降法来整.
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通