线性回归算法(Java)

 

 

手动实现线性回归(梯度下降法)

 1 public class LinearRegressionGD {
 2     private double learningRate;
 3     private int iterations;
 4     private double slope;
 5     private double intercept;
 6 
 7     public LinearRegressionGD(double learningRate, int iterations) {
 8         this.learningRate = learningRate;
 9         this.iterations = iterations;
10         this.slope = 0;
11         this.intercept = 0;
12     }
13 
14     public void fit(double[] x, double[] y) {
15         int n = x.length;
16 
17         for (int i = 0; i < iterations; i++) {
18             double slopeGradient = 0;
19             double interceptGradient = 0;
20 
21             // 计算梯度
22             for (int j = 0; j < n; j++) {
23                 double prediction = slope * x[j] + intercept;
24                 slopeGradient += - (2.0 / n) * x[j] * (y[j] - prediction);
25                 interceptGradient += - (2.0 / n) * (y[j] - prediction);
26             }
27 
28             // 更新参数
29             slope -= learningRate * slopeGradient;
30             intercept -= learningRate * interceptGradient;
31         }
32     }
33 
34     public double predict(double x) {
35         return slope * x + intercept;
36     }
37 
38     public static void main(String[] args) {
39         double[] x = {1, 2, 3, 4, 5};
40         double[] y = {50, 55, 65, 70, 85};
41 
42         LinearRegressionGD model = new LinearRegressionGD(0.01, 1000); // 设置学习率和迭代次数
43         model.fit(x, y);
44 
45         double predictedScore = model.predict(6); // 预测学习时间为6小时时的考试成绩
46         System.out.println("预测的考试成绩: " + predictedScore);
47     }
48 }

代码解析

  1. 构造函数:初始化学习率和迭代次数。
  2. fit 方法:使用梯度下降法更新斜率和截距。predict 方法:根据学习到的斜率和截距进行预测。
    • 计算当前预测值与实际值之间的误差,并根据误差计算梯度。
    • 通过梯度调整斜率和截距。
  3. main 方法:创建数据集,实例化模型,训练模型,并进行预测。

参数设置

  • learningRate:控制每次更新的步长。较小的学习率可能导致收敛速度慢,而较大的学习率可能导致不收敛。
  • iterations:指定训练模型的轮数。过少的迭代可能导致模型未收敛,过多的迭代则可能导致过拟合。

通过这种手动实现的方式,你可以灵活地控制学习率和迭代次数,以优化模型的性能。

posted @ 2024-10-11 14:01  琅琊甲乙木  阅读(82)  评论(0编辑  收藏  举报