一元线性回归分析及java实现

http://blog.csdn.net/hwwn2009/article/details/38414911

 

一元线性回归分析及java实现

 分类:

一元线性回归分析是处理两个变量之间关系的最简单模型,它所研究的对象是两个变量之间的线性相关关系。通过对这个模型的讨论,我们不仅可以掌握有关一元线性回归的知识,而且可以从中了解回归分析方法的基本思想、方法和应用。

 一、问题的提出

 例2-1-1  为了研究氮含量对铁合金溶液初生奥氏体析出温度的影响,测定了不同氮含量时铁合金溶液初生奥氏体析出温度,得到表2-1-1给出的5组数据。

表2-1-1   氮含量与灰铸铁初生奥氏体析出温度测试数据

    如果把氮含量作为横坐标,把初生奥氏体析出温度作为纵坐标,将这些数据标在平面直角坐标上,则得图2-1-1,这个图称为散点图。

从图2-1-1可以看出,数据点基本落在一条直线附近。这告诉我们,变量X与Y的关系大致可看作是线性关系,即它们之间的相互关系可以用线性关系来描述。但是由于并非所有的数据点完全落在一条直线上,因此X与Y的关系并没有确切到可以唯一地由一个X值确定一个Y值的程度。其它因素,诸如其它微量元素的含量以及测试误差等都会影响Y的测试结果。如果我们要研究X与Y的关系,可以作线性拟合

           (2-1-1)

 我们称(2-1-1)式为回归方程,a与b是待定常数,称为回归系数。从理论上讲,(2-1-1)式有无穷多组解,回归分析的任务是求出其最佳的线性拟合。

 二、最小二乘法原理

 如果把用回归方程 计算得到的 i值(i=1,2,…n)称为回归值,那么实际测量值yi与回归值 i之间存在着偏差,我们把这种偏差称为残差,记为ei(i=1,2,3,…,n)。这样,我们就可以用残差平方和来度量测量值与回归直线的接近或偏差程度。残差平方和定义为:

     (2-1-2)

所谓最小二乘法,就是选择a和b使Q(a,b)最小,即用最小二乘法得到的回归直线 是在所有直线中与测量值残差平方和Q最小的一条。由(2-1-2)式可知Q是关于a,b的二次函数,所以它的最小值总是存在的。下面讨论的a和b的求法。

三、正规方程组

根据微分中求极值的方法可知,Q(a,b)取得最小值应满足

                                (2-1-3)

由(2-1-2)式,并考虑上述条件,则

            (2-1-4)

(2-1-4)式称为正规方程组。解这一方程组可得

                       (2-1-5)

   其中

                       (2-1-6)

   (2-1-7)

    式中,Lxy称为xy的协方差之和,Lxx称为x的平方差之和。

如果改写(2-1-1)式,可得

                      (2-1-8)

    或

                        (2-1-9)

 由此可见,回归直线是通过点 的,即通过由所有实验测量值的平均值组成的点。从力学观点看, 即是N个散点 的重心位置。

 现在我们来建立关于例1的回归关系式。将表2-1-1的结果代入(2-1-5)式至(2-1-7)式,得出

a=1231.65

b=-2236.63

 因此,在例1中灰铸铁初生奥氏体析出温度(y)与氮含量(x)的回归关系式为

y=1231.65-2236.63x

 

 四、一元线性回归的统计学原理

 如果X和Y都是相关的随机变量,在确定x的条件下,对应的y值并不确定,而是形成一个分布。当X取确定的值时,Y的数学期望值也就确定了,因此Y的数学期望是x的函数,即

E(Y|X=x)=f(x)                  (2-1-10)

 这里方程f(x)称为Y对X的回归方程。如果回归方程是线性的,则

E(Y|X=x)=α+βx                (2-1-11)

 或

Y=α+βx+ε                    (2-1-12)

 其中

     ε―随机误差

 从样本中我们只能得到关于特征数的估计,并不能精确地求出特征数。因此只能用f(x)的估计式   来取代(2-1-11)式,用参数a和b分别作为α和β的估计量。那么,这两个估计量是否能够满足要求呢?

 

 1. 无偏性

 把(x,y)的n组观测值作为一个样本,由样本只能得到总体参数α和β的估计值。可以证明,当满足下列条件:

 (1)(xi,yi)是n个相互独立的观测值

 (2)εi是服从 分布的随机变量

 则由最小二乘法得到的a与b分别是总体参数α和β的无偏估计,即

E(a)= α

E(b)=β

    由此可推知

E( )=E(y)

    即y是回归值 在某点的数学期望值。

 2. a和b的方差

 可以证明,当n组观测值(xi,yi)相互独立,并且D(yi)=σ2,时,a和b的方差为

                               (2-1-13)

                  (2-1-14)

以上两式表明,a和b的方差均与xi的变动有关,xi分布越宽,则a和b的方差越小。另外a的方差还与观测点的数量有关,数据越多,a的方差越小。因此,为提高估计量的准确性,xi的分布应尽量宽,观测点数量应尽量多。

 

java实现

 

1、定义一个DataPoint类,对X和Y坐标点进行封装:

 

[java] view plaincopy在CODE上查看代码片派生到我的代码片
 
 
  1. /** 
  2.  * File        : DataPoint.java 
  3.  * Author      : zhouyujie 
  4.  * Date        : 2012-01-11 16:00:00 
  5.  * Description : Java实现一元线性回归的算法,座标点实体类,(可实现统计指标的预测) 
  6.  */  
  7. package com.zyujie.dm;  
  8.   
  9. public class DataPoint {  
  10.   
  11.     /** the x value */  
  12.     public float x;  
  13.   
  14.     /** the y value */  
  15.     public float y;  
  16.   
  17.     /** 
  18.      * Constructor. 
  19.      *  
  20.      * @param x 
  21.      *            the x value 
  22.      * @param y 
  23.      *            the y value 
  24.      */  
  25.     public DataPoint(float x, float y) {  
  26.         this.x = x;  
  27.         this.y = y;  
  28.     }  
  29. }  
2、下面是算法实现回归线:

 

 

[java] view plaincopy在CODE上查看代码片派生到我的代码片
 
 
  1. /** 
  2.  * File        : DataPoint.java 
  3.  * Author      : zhouyujie 
  4.  * Date        : 2012-01-11 16:00:00 
  5.  * Description : Java实现一元线性回归的算法,回归线实现类,(可实现统计指标的预测) 
  6.  */  
  7. package com.zyujie.dm;  
  8.   
  9. import java.math.BigDecimal;  
  10. import java.util.ArrayList;  
  11.   
  12. public class RegressionLine // implements Evaluatable  
  13. {  
  14.     /** sum of x */  
  15.     private double sumX;  
  16.   
  17.     /** sum of y */  
  18.     private double sumY;  
  19.   
  20.     /** sum of x*x */  
  21.     private double sumXX;  
  22.   
  23.     /** sum of x*y */  
  24.     private double sumXY;  
  25.   
  26.     /** sum of y*y */  
  27.     private double sumYY;  
  28.   
  29.     /** sum of yi-y */  
  30.     private double sumDeltaY;  
  31.   
  32.     /** sum of sumDeltaY^2 */  
  33.     private double sumDeltaY2;  
  34.   
  35.     /** 误差 */  
  36.     private double sse;  
  37.   
  38.     private double sst;  
  39.   
  40.     private double E;  
  41.   
  42.     private String[] xy;  
  43.   
  44.     private ArrayList listX;  
  45.   
  46.     private ArrayList listY;  
  47.   
  48.     private int XMin, XMax, YMin, YMax;  
  49.   
  50.     /** line coefficient a0 */  
  51.     private float a0;  
  52.   
  53.     /** line coefficient a1 */  
  54.     private float a1;  
  55.   
  56.     /** number of data points */  
  57.     private int pn;  
  58.   
  59.     /** true if coefficients valid */  
  60.     private boolean coefsValid;  
  61.   
  62.     /** 
  63.      * Constructor. 
  64.      */  
  65.     public RegressionLine() {  
  66.         XMax = 0;  
  67.         YMax = 0;  
  68.         pn = 0;  
  69.         xy = new String[2];  
  70.         listX = new ArrayList();  
  71.         listY = new ArrayList();  
  72.     }  
  73.   
  74.     /** 
  75.      * Constructor. 
  76.      *  
  77.      * @param data 
  78.      *            the array of data points 
  79.      */  
  80.     public RegressionLine(DataPoint data[]) {  
  81.         pn = 0;  
  82.         xy = new String[2];  
  83.         listX = new ArrayList();  
  84.         listY = new ArrayList();  
  85.         for (int i = 0; i < data.length; ++i) {  
  86.             addDataPoint(data[i]);  
  87.         }  
  88.     }  
  89.   
  90.     /** 
  91.      * Return the current number of data points. 
  92.      *  
  93.      * @return the count 
  94.      */  
  95.     public int getDataPointCount() {  
  96.         return pn;  
  97.     }  
  98.   
  99.     /** 
  100.      * Return the coefficient a0. 
  101.      *  
  102.      * @return the value of a0 
  103.      */  
  104.     public float getA0() {  
  105.         validateCoefficients();  
  106.         return a0;  
  107.     }  
  108.   
  109.     /** 
  110.      * Return the coefficient a1. 
  111.      *  
  112.      * @return the value of a1 
  113.      */  
  114.     public float getA1() {  
  115.         validateCoefficients();  
  116.         return a1;  
  117.     }  
  118.   
  119.     /** 
  120.      * Return the sum of the x values. 
  121.      *  
  122.      * @return the sum 
  123.      */  
  124.     public double getSumX() {  
  125.         return sumX;  
  126.     }  
  127.   
  128.     /** 
  129.      * Return the sum of the y values. 
  130.      *  
  131.      * @return the sum 
  132.      */  
  133.     public double getSumY() {  
  134.         return sumY;  
  135.     }  
  136.   
  137.     /** 
  138.      * Return the sum of the x*x values. 
  139.      *  
  140.      * @return the sum 
  141.      */  
  142.     public double getSumXX() {  
  143.         return sumXX;  
  144.     }  
  145.   
  146.     /** 
  147.      * Return the sum of the x*y values. 
  148.      *  
  149.      * @return the sum 
  150.      */  
  151.     public double getSumXY() {  
  152.         return sumXY;  
  153.     }  
  154.   
  155.     public double getSumYY() {  
  156.         return sumYY;  
  157.     }  
  158.   
  159.     public int getXMin() {  
  160.         return XMin;  
  161.     }  
  162.   
  163.     public int getXMax() {  
  164.         return XMax;  
  165.     }  
  166.   
  167.     public int getYMin() {  
  168.         return YMin;  
  169.     }  
  170.   
  171.     public int getYMax() {  
  172.         return YMax;  
  173.     }  
  174.   
  175.     /** 
  176.      * Add a new data point: Update the sums. 
  177.      *  
  178.      * @param dataPoint 
  179.      *            the new data point 
  180.      */  
  181.     public void addDataPoint(DataPoint dataPoint) {  
  182.         sumX += dataPoint.x;  
  183.         sumY += dataPoint.y;  
  184.         sumXX += dataPoint.x * dataPoint.x;  
  185.         sumXY += dataPoint.x * dataPoint.y;  
  186.         sumYY += dataPoint.y * dataPoint.y;  
  187.   
  188.         if (dataPoint.x > XMax) {  
  189.             XMax = (int) dataPoint.x;  
  190.         }  
  191.         if (dataPoint.y > YMax) {  
  192.             YMax = (int) dataPoint.y;  
  193.         }  
  194.   
  195.         // 把每个点的具体坐标存入ArrayList中,备用  
  196.   
  197.         xy[0] = (int) dataPoint.x + "";  
  198.         xy[1] = (int) dataPoint.y + "";  
  199.         if (dataPoint.x != 0 && dataPoint.y != 0) {  
  200.             System.out.print(xy[0] + ",");  
  201.             System.out.println(xy[1]);  
  202.   
  203.             try {  
  204.                 // System.out.println("n:"+n);  
  205.                 listX.add(pn, xy[0]);  
  206.                 listY.add(pn, xy[1]);  
  207.             } catch (Exception e) {  
  208.                 e.printStackTrace();  
  209.             }  
  210.   
  211.             /* 
  212.              * System.out.println("N:" + n); System.out.println("ArrayList 
  213.              * listX:"+ listX.get(n)); System.out.println("ArrayList listY:"+ 
  214.              * listY.get(n)); 
  215.              */  
  216.         }  
  217.         ++pn;  
  218.         coefsValid = false;  
  219.     }  
  220.   
  221.     /** 
  222.      * Return the value of the regression line function at x. (Implementation of 
  223.      * Evaluatable.) 
  224.      *  
  225.      * @param x 
  226.      *            the value of x 
  227.      * @return the value of the function at x 
  228.      */  
  229.     public float at(int x) {  
  230.         if (pn < 2)  
  231.             return Float.NaN;  
  232.   
  233.         validateCoefficients();  
  234.         return a0 + a1 * x;  
  235.     }  
  236.   
  237.     /** 
  238.      * Reset. 
  239.      */  
  240.     public void reset() {  
  241.         pn = 0;  
  242.         sumX = sumY = sumXX = sumXY = 0;  
  243.         coefsValid = false;  
  244.     }  
  245.   
  246.     /** 
  247.      * Validate the coefficients. 计算方程系数 y=ax+b 中的a 
  248.      */  
  249.     private void validateCoefficients() {  
  250.         if (coefsValid)  
  251.             return;  
  252.   
  253.         if (pn >= 2) {  
  254.             float xBar = (float) sumX / pn;  
  255.             float yBar = (float) sumY / pn;  
  256.   
  257.             a1 = (float) ((pn * sumXY - sumX * sumY) / (pn * sumXX - sumX  
  258.                     * sumX));  
  259.             a0 = (float) (yBar - a1 * xBar);  
  260.         } else {  
  261.             a0 = a1 = Float.NaN;  
  262.         }  
  263.   
  264.         coefsValid = true;  
  265.     }  
  266.   
  267.     /** 
  268.      * 返回误差 
  269.      */  
  270.     public double getR() {  
  271.         // 遍历这个list并计算分母  
  272.         for (int i = 0; i < pn - 1; i++) {  
  273.             float Yi = (float) Integer.parseInt(listY.get(i).toString());  
  274.             float Y = at(Integer.parseInt(listX.get(i).toString()));  
  275.             float deltaY = Yi - Y;  
  276.             float deltaY2 = deltaY * deltaY;  
  277.             /* 
  278.              * System.out.println("Yi:" + Yi); System.out.println("Y:" + Y); 
  279.              * System.out.println("deltaY:" + deltaY); 
  280.              * System.out.println("deltaY2:" + deltaY2); 
  281.              */  
  282.   
  283.             sumDeltaY2 += deltaY2;  
  284.             // System.out.println("sumDeltaY2:" + sumDeltaY2);  
  285.   
  286.         }  
  287.   
  288.         sst = sumYY - (sumY * sumY) / pn;  
  289.         // System.out.println("sst:" + sst);  
  290.         E = 1 - sumDeltaY2 / sst;  
  291.   
  292.         return round(E, 4);  
  293.     }  
  294.   
  295.     // 用于实现精确的四舍五入  
  296.     public double round(double v, int scale) {  
  297.   
  298.         if (scale < 0) {  
  299.             throw new IllegalArgumentException(  
  300.                     "The scale must be a positive integer or zero");  
  301.         }  
  302.   
  303.         BigDecimal b = new BigDecimal(Double.toString(v));  
  304.         BigDecimal one = new BigDecimal("1");  
  305.         return b.divide(one, scale, BigDecimal.ROUND_HALF_UP).doubleValue();  
  306.   
  307.     }  
  308.   
  309.     public float round(float v, int scale) {  
  310.   
  311.         if (scale < 0) {  
  312.             throw new IllegalArgumentException(  
  313.                     "The scale must be a positive integer or zero");  
  314.         }  
  315.   
  316.         BigDecimal b = new BigDecimal(Double.toString(v));  
  317.         BigDecimal one = new BigDecimal("1");  
  318.         return b.divide(one, scale, BigDecimal.ROUND_HALF_UP).floatValue();  
  319.   
  320.     }  
  321. }  
3、线性回归测试类:

 

 

[java] view plaincopy在CODE上查看代码片派生到我的代码片
 
 
  1. /** 
  2.  * File        : DataPoint.java 
  3.  * Author      : zhouyujie 
  4.  * Date        : 2012-01-11 16:00:00 
  5.  * Description : Java实现一元线性回归的算法,线性回归测试类,(可实现统计指标的预测) 
  6.  */  
  7. package com.zyujie.dm;  
  8.   
  9. /** 
  10.  * <p> 
  11.  * <b>Linear Regression</b> <br> 
  12.  * Demonstrate linear regression by constructing the regression line for a set 
  13.  * of data points. 
  14.  *  
  15.  * <p> 
  16.  * require DataPoint.java,RegressionLine.java 
  17.  *  
  18.  * <p> 
  19.  * 为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2)) 
  20.  * <p> 
  21.  * <b>回归直线方程如下: f(x)=a1x+a0 </b> 
  22.  * <p> 
  23.  * <b>斜率和截距的计算公式如下:</b> <br> 
  24.  * n: 数据点个数 
  25.  * <p> 
  26.  * a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2) <br> 
  27.  * a0=(SumY - SumY * a1)/n <br> 
  28.  * (也可表达为a0=averageY-a1*averageX) 
  29.  *  
  30.  * <p> 
  31.  * <b>画线的原理:两点成一直线,只要能确定两个点即可</b><br> 
  32.  * 第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。 
  33.  * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax+a0,y大于 
  34.  * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax) 
  35.  *  
  36.  * <p> 
  37.  * <b>拟合度计算:(即Excel中的R^2)</b> 
  38.  * <p> 
  39.  * *R2 = 1 - E 
  40.  * <p> 
  41.  * 误差E的计算:E = SSE/SST 
  42.  * <p> 
  43.  * SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n; 
  44.  * <p> 
  45.  */  
  46. public class LinearRegression {  
  47.   
  48.     private static final int MAX_POINTS = 10;  
  49.   
  50.     private double E;  
  51.   
  52.     /** 
  53.      * Main program. 
  54.      *  
  55.      * @param args 
  56.      *            the array of runtime arguments 
  57.      */  
  58.     public static void main(String args[]) {  
  59.         RegressionLine line = new RegressionLine();  
  60.   
  61.         line.addDataPoint(new DataPoint(1, 136));  
  62.         line.addDataPoint(new DataPoint(2, 143));  
  63.         line.addDataPoint(new DataPoint(3, 132));  
  64.         line.addDataPoint(new DataPoint(4, 142));  
  65.         line.addDataPoint(new DataPoint(5, 147));  
  66.   
  67.         printSums(line);  
  68.         printLine(line);  
  69.     }  
  70.   
  71.     /** 
  72.      * Print the computed sums. 
  73.      *  
  74.      * @param line 
  75.      *            the regression line 
  76.      */  
  77.     private static void printSums(RegressionLine line) {  
  78.         System.out.println("\n数据点个数 n = " + line.getDataPointCount());  
  79.         System.out.println("\nSum x  = " + line.getSumX());  
  80.         System.out.println("Sum y  = " + line.getSumY());  
  81.         System.out.println("Sum xx = " + line.getSumXX());  
  82.         System.out.println("Sum xy = " + line.getSumXY());  
  83.         System.out.println("Sum yy = " + line.getSumYY());  
  84.   
  85.     }  
  86.   
  87.     /** 
  88.      * Print the regression line function. 
  89.      *  
  90.      * @param line 
  91.      *            the regression line 
  92.      */  
  93.     private static void printLine(RegressionLine line) {  
  94.         System.out.println("\n回归线公式:  y = " + line.getA1() + "x + "  
  95.                 + line.getA0());  
  96.         System.out.println("误差:     R^2 = " + line.getR());  
  97.     }  
  98.       
  99.     //y = 2.1x + 133.7   2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3  
  100.     //y = 2.1x + 133.7   2.1 * 7 + 133.7 = 14.7 + 133.7 = 148.4  
  101.   
  102. }  

 

我们运行测试类,得到运行结果:

1,136
2,143
3,132
4,142
5,147

数据点个数 n = 5

Sum x  = 15.0
Sum y  = 700.0
Sum xx = 55.0
Sum xy = 2121.0
Sum yy = 98142.0

回归线公式:  y = 2.1x + 133.7
误差:     R^2 = 0.3658

假如某公司:

1月收入,136万元
2月收入,143万元
3月收入,132万元
4月收入,142万元
5月收入,147万元

我们可以根据回归线公式:y = 2.1x + 133.7,预测出6月份收入:

y = 2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3

posted @ 2016-08-17 03:37  donaldlee  阅读(1748)  评论(1编辑  收藏  举报