线性回归
本文包含的内容有:
- 参数的求解
- 截距项的处理
- 脊回归
- 参数的计算技巧
- 贝叶斯角度的线性回归
引言
线性回归作为一个统计学里最基础的模型,广泛应用在各个场景下:如对房价的预测,销售额的预测。
基础的线性回归模型表达如下:
在应用中可以根据实际问题更改误差项
另外线性回归中的线性指的是
参数的求解
假设误差
无论是通过MLE的极大似然,还是通过投影的物理意义,都可以得到如下结果:
除了上述方法也可用对cost function应用梯度下降来做。
截距项的处理
通过对cost function的
因此我们在求解
可以看出,直线一定经过
因此拿到数据后可以先进行归一化处理,这样就可以求解参数
脊回归
因为极大似然容易导致过拟合,通过施加先验用最大后验估计的方法可以减缓过拟合。脊回归即对参数
计算技巧
式(2)和式(1)具有一定的相似性,通过一下变换(2)也可以写为(1)的形式,令:
可以发现,
因此
通过观察无论是极大似然估计还是最大后验估计,参数的求解式子中均包含求逆的运算,而求逆的运算是很麻烦的,应该尽可能避免。下面对求逆运算进行简化。
当
而
当
这样整个的计算复杂度为
贝叶斯角度的线性回归
虽然上述方法,例如脊回归中的高斯先验,也用了贝叶斯的思想,但在求解的时候没有求
假设已经知道了参数
如果要对
频率学派反对贝叶斯学派的一个重要观点就是先验的选择依赖主观判断。贝叶斯学派也对此也提出了解决方法,如无信息的先验或者经验贝叶斯。
下面放一个经验贝叶斯的代码,这部分的内容可以再"Machine Learning A Probability Perspective"的第五章、第七章找到。
# using utf-8
"""
所建立的模型为p(y) = N(x'w, 4),采用EB方法估计w
设置w先验为参数是μ,Σ的高斯分布
"""
import numpy as np
import matplotlib.pyplot as plt
Ns = 5 # 观测的样本数
np.random.seed(2)
x = np.random.randn(Ns, 1) * 5 # 观测点x值
error = np.random.randn(Ns, 1) # 误差
y_train = (x - 4) ** 2 + 2 * error # y观测值,观测的方差为4
y_true = (x - 4) ** 2 # y真实值
# 先对两个变量的进行设置
# 减去均值忽略掉截距项
x_expand = np.hstack((x, x**2))
x_expand_center = x_expand - x_expand.mean(axis=0)
y_train_center = y_train - y_train.mean(axis=0)
error_var = 4
"""
估计μ和Σ的取值
p(D|μ,Σ) = ∫ p(D|w) p(w|μ,Σ) dw
p(D|w) = N(y_mean|x_mean w,σ**2/N)
p(D|μ,Σ) = N(y_mean|x_mean u, σ**2/N + x_mean Σ x_mean')
x_mean μ = y_mean; μ = (x_mean' x_mean) ** -1 x_mean' y_mean
s ** 2 = σ**2/N + x_mean Σ x_mean'; Σ = (s ** 2 - σ ** 2 /N)(x_mean' x_mean) ** -1
"""
mu_prior = (np.linalg.pinv((x_expand_center.T.dot(x_expand_center))).dot(x_expand_center.T)).dot(y_train_center)
U, S, V_h = np.linalg.svd(x_expand_center, full_matrices=False)
V = V_h.T
mu_prior_svd = V.dot(np.diag(1 / S)).dot(U.T).dot(y_train_center)
print("按照伪逆计算的mu_prior", mu_prior)
print("按照SVD计算的mu_prior", mu_prior_svd)
sigma_prior = (y_train_center.var() - error_var) * np.linalg.pinv(x_expand_center.T.dot(x_expand_center))
sigma_prior_svd = (y_train_center.var() - error_var) * V.dot(np.diag(1/S**2)).dot(V.T)
print("按照伪逆计算的sigma_prior", sigma_prior)
print("按照SVD计算的sigma_prior", sigma_prior_svd)
"""
在了解先验μ和Σ后,对参数w进行估计
再利用对高斯线性模型
p(D|w) = N(y_mean| x_mean w, σ ** 2/N) p(w) = N(w|μ,Σ)
w_sigma = (Σ ** -1 + x_mean' (σ ** 2/N)**-1 x_mean) ** -1
w_mu = w_sigma(Σ **-1 μ + x_mean' (σ ** 2/N) ** -1 y_mean)
p(w|D) = N(w|, (σ**2)**-1+x_mean' Σ ** {-1} x_mean)
"""
w_sigma = np.linalg.pinv(np.linalg.pinv(sigma_prior) +
1/ error_var * x_expand_center.T.dot(x_expand_center))
w_sigma_svd = V.dot(np.diag(1/S **2)).dot(V.T) * (y_train_center.var() - error_var) * \
error_var / y_train_center.var()
print("按照伪逆计算的参数协方差后验:", w_sigma)
print("按照svd计算的参数协方差后验:", w_sigma_svd)
w_mu = w_sigma.dot(np.linalg.pinv(sigma_prior).dot(mu_prior) + 1/ error_var * x_expand_center.T.dot(y_train_center))
w_mu_svd = w_sigma_svd.dot(V.dot(np.diag(1/S**2)).dot(V.T).dot(mu_prior_svd) /
(y_train_center.var() - error_var) +
1 / error_var * x_expand_center.T.dot(y_train_center))
print("按照伪逆计算的参数均值差后验:", w_mu)
print("按照svd计算的参数均值差后验:", w_mu_svd)
w_0 = y_train_center.mean() - x_expand_center.mean(axis=0).dot(w_mu_svd)
y_predict = x_expand.dot(w_mu_svd) + w_0
"""
预测y
p(y|D) ∝ p(w|D)p(y|w)
再次利用高斯线性系统公式得到
p(y|D) = N(y|x w_mu, error_var + x w_wigma x')
"""
x_test = np.linspace(-30, 30, 24)[:, np.newaxis]
x_test_expand = np.hstack((x_test, x_test **2))
y_test = (x_test - 4) ** 2
y_predict_mean = x_test_expand.dot(w_mu_svd) +w_0
y_predict_std = np.sqrt(error_var + np.diag(x_test_expand.dot(w_sigma_svd.dot(x_test_expand.T))))[:, np.newaxis]
figure = plt.figure(1)
plt.plot(x_test.ravel(), y_test.ravel(), label="True")
plt.plot(x_test.ravel(), y_predict_mean.ravel(), label="predict")
plt.fill_between(x_test.ravel(), (y_predict_mean - 2 * y_predict_std).ravel(),
(y_predict_mean + 2*y_predict_std).ravel(), color="gray", alpha=0.2)
plt.scatter(x.ravel(), y_predict.ravel(), color="blue" ,label="train")
plt.legend(loc="best")
plt.xlabel("X")
plt.ylabel("y")
plt.title("EB for linear regression")
plt.show()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· .NET10 - 预览版1新功能体验(一)