线性回归算法(Java)
手动实现线性回归(梯度下降法)
1 public class LinearRegressionGD {
2 private double learningRate;
3 private int iterations;
4 private double slope;
5 private double intercept;
6
7 public LinearRegressionGD(double learningRate, int iterations) {
8 this.learningRate = learningRate;
9 this.iterations = iterations;
10 this.slope = 0;
11 this.intercept = 0;
12 }
13
14 public void fit(double[] x, double[] y) {
15 int n = x.length;
16
17 for (int i = 0; i < iterations; i++) {
18 double slopeGradient = 0;
19 double interceptGradient = 0;
20
21 // 计算梯度
22 for (int j = 0; j < n; j++) {
23 double prediction = slope * x[j] + intercept;
24 slopeGradient += - (2.0 / n) * x[j] * (y[j] - prediction);
25 interceptGradient += - (2.0 / n) * (y[j] - prediction);
26 }
27
28 // 更新参数
29 slope -= learningRate * slopeGradient;
30 intercept -= learningRate * interceptGradient;
31 }
32 }
33
34 public double predict(double x) {
35 return slope * x + intercept;
36 }
37
38 public static void main(String[] args) {
39 double[] x = {1, 2, 3, 4, 5};
40 double[] y = {50, 55, 65, 70, 85};
41
42 LinearRegressionGD model = new LinearRegressionGD(0.01, 1000); // 设置学习率和迭代次数
43 model.fit(x, y);
44
45 double predictedScore = model.predict(6); // 预测学习时间为6小时时的考试成绩
46 System.out.println("预测的考试成绩: " + predictedScore);
47 }
48 }
代码解析
- 构造函数:初始化学习率和迭代次数。
- fit 方法:使用梯度下降法更新斜率和截距。predict 方法:根据学习到的斜率和截距进行预测。
- 计算当前预测值与实际值之间的误差,并根据误差计算梯度。
- 通过梯度调整斜率和截距。
- main 方法:创建数据集,实例化模型,训练模型,并进行预测。
参数设置
learningRate
:控制每次更新的步长。较小的学习率可能导致收敛速度慢,而较大的学习率可能导致不收敛。iterations
:指定训练模型的轮数。过少的迭代可能导致模型未收敛,过多的迭代则可能导致过拟合。
通过这种手动实现的方式,你可以灵活地控制学习率和迭代次数,以优化模型的性能。