拟合

 

  1 //LeastSquaresFitter.h
  2 #pragma once
  3 #include <WTypes.h>
  4 #include <cmath>
  5 #include <vector>
  6 #include <map>
  7 using namespace std;
  8 
  9 //多项式 f(x) = c_0 + c_1*x^1 + c_2*x^2 + ... + c_N*x^N
 10 class CPolyFunction { 
 11 public:
 12     CPolyFunction() : m_dim(0), m_expr(){}
 13     //根据多项式系数向量构造多项式表达式 c_0 + ,c = {}
 14     CPolyFunction(vector<DOUBLE>& c) : m_dim(0), m_expr() 
 15     {
 16         INT32 len = c.size();
 17         for (INT32 i = 0; i < len; ++i) { 
 18             if (c[i]) { 
 19                 m_expr.insert(make_pair(i, c[i])); 
 20                 m_dim = i; 
 21             }
 22         } 
 23     }
 24     ~CPolyFunction(){}
 25     //输入x计算多项式表达式值y
 26      DOUBLE operator ()(DOUBLE x) 
 27      { 
 28          DOUBLE ans = m_expr.at(m_dim); 
 29          for (INT32 i = m_dim; i > 0; --i) { 
 30              if (m_expr.end() != m_expr.find(i)) { 
 31                  ans = m_expr[i - 1] + x * ans; 
 32              } 
 33          } 
 34          return ans; 
 35      } 
 36 private:
 37     INT32 m_dim;
 38     //多项式各项指数-系数映射
 39     map<INT32, DOUBLE> m_expr;
 40 };
 41 
 42 class CLeastSquaresFitter
 43 {
 44 public:
 45     CLeastSquaresFitter(void) : m_ssr(0), m_sse(0), m_rmse(0), m_dim(0) {}
 46     ~CLeastSquaresFitter(void){}
 47     CLeastSquaresFitter(vector<DOUBLE>& X, 
 48         vector<DOUBLE>& Y) : m_ssr(0), m_sse(0), m_rmse(0), m_dim(0), m_X(X), m_Y(Y){}
 49     //返回拟合多项式的系数
 50     vector<DOUBLE> GetCoef()
 51     {
 52         return m_coef;
 53     }
 54     //返回与观测数据向量X对应的拟合数据向量FitY
 55     vector<DOUBLE> GetFitY()
 56     {
 57         return m_FitY;
 58     }
 59     //返回拟合多项式的阶数
 60     INT32 GetPolyFitDim()
 61     {
 62         return m_dim;
 63     }
 64     //返回拟合数据的回归平方和
 65     DOUBLE GetSSR()
 66     {
 67         return m_ssr;
 68     }
 69     //剩余平方和
 70     DOUBLE GetSSE()
 71     { 
 72         return m_sse;
 73     }
 74     //RMSE均方根误差
 75     DOUBLE GetRMSE()
 76     {
 77         return m_rmse;
 78     }
 79     //返回卡方统计量chi-squared
 80     DOUBLE GetChisq()
 81     {
 82         return 1 - (m_sse / (m_ssr + m_sse));
 83     }
 84     //线性最小二乘拟合 y = ax + b
 85     BOOL LinearFit(BOOL isSaveFitY = TRUE)
 86     {
 87         m_coef.resize(2, 0);
 88         DOUBLE t1 = 0, t2 = 0, t3 = 0, t4 = 0;
 89         UINT32 length = m_X.size() < m_Y.size() ? m_X.size() : m_Y.size();
 90         for (UINT32 i = 0; i < length; ++i) {
 91             t1 += m_X[i] * m_X[i];
 92             t2 += m_X[i];
 93             t3 += m_X[i] * m_Y[i];
 94             t4 += m_Y[i];
 95         }
 96         m_coef[0] = ((t1 * t4 - t2 * t3) / (t1 * length - t2 * t2));
 97         m_coef[1] = ((t3 * length - t2 * t4) / (t1 * length - t2 * t2));
 98         Calculate(isSaveFitY);
 99         return TRUE;
100     }
101 
102     //多项式拟合,拟合多项式原型y(x) = c_0 + c_1*x^1 + c_2*x^2 + ... + c_N*x^N
103     //N即拟合多项式的阶
104     void PolyFit(UINT32 n = 1, BOOL isSaveFitY = TRUE)
105     {
106         UINT32 length = m_X.size() < m_Y.size() ? m_X.size() : m_Y.size();
107         m_coef.resize(n + 1, 0);
108         vector<DOUBLE> tpX(length, 1.0);
109         vector<DOUBLE> tpY(m_Y);
110         vector<DOUBLE> sumxx(n * 2 + 1);
111         vector<DOUBLE> ata((n + 1) * (n + 1));
112         vector<DOUBLE> sumxy(n + 1);
113         UINT32 i = 0, j = 0;
114         for (i = 0; i < 2 * n + 1; ++i) {
115             for (sumxx[i] = 0, j = 0; j < length; ++j) {
116                 sumxx[i] += tpX[j];
117                 tpX[j] *= m_X[j];
118             }
119         }
120 
121         for (i = 0; i < n + 1; ++i) {
122             for (sumxy[i] = 0, j = 0; j < length; ++j) {
123                 sumxy[i] += tpY[j];
124                 tpY[j] *= m_X[j];
125             }
126         }
127 
128         for (i = 0; i < n + 1; ++i) {
129             for (j = 0; j < n + 1; ++j) {
130                 ata[i * (n + 1) + j] = sumxx[i + j];
131             }
132         }
133 
134         GaussSolve(n + 1, ata, m_coef, sumxy);
135         Calculate(isSaveFitY);
136     }
137 
138 
139 protected:
140     //高斯求解方程组Ax=b
141     void GaussSolve(INT32 n, vector<DOUBLE>& A, vector<DOUBLE>& x, vector<DOUBLE>& b)
142     {
143         INT32 i, j, k, r;
144         DOUBLE max;
145         for (k = 0; k < n - 1; ++k) {
146             /*find maxmum*/
147             max = fabs(A[k * n + k]);
148             r = k;
149             for (i = k + 1; i < n - 1; ++i) {
150                 if (max < fabs(A[i * n + i])) {
151                     max = fabs(A[i * n + i]);
152                     r = i;
153                 }
154             }
155             if (r != k) {
156                 /*change array A[k] and A[r]*/
157                 for (i = 0; i < n; ++i) {
158                     max = A[k * n + i];
159                     A[k * n + i] = A[r * n + i];
160                     A[r * n + i] = max;
161                 }
162             }
163             /*change array b[k] and b[r]*/
164             max = b[k];
165             b[k] = b[r];
166             b[r] = max;
167             for (i = k + 1; i < n; ++i) {
168                 for (j = k + 1; j < n; ++j) {
169                     A[i * n + j] -= A[i * n + k] * A[k * n + j] / A[k * n + k];
170                 }
171                 b[i] -= A[i * n + k] * b[k] / A[k * n + k];
172             }
173         }
174 
175         for (i = n - 1; i >= 0; x[i] /= A[i * n + i], --i) {
176             for (j = i + 1, x[i] = b[i]; j < n; ++j) {
177                 x[i] -= A[i * n + j] * x[j];
178             }
179         }
180     }
181     //返回x[i]对应当拟合后的y值
182     DOUBLE GetY(const DOUBLE x) const
183     {
184         DOUBLE ans = 0;
185         for (UINT32 i = 0; i < m_coef.size(); ++i) {
186             ans += m_coef[i] * pow(x, i);
187         }
188         return ans;
189     }
190     //计算拟合后的Y值及产生的最小二乘意义上的误差
191     VOID Calculate(BOOL isSaveFitY = TRUE)
192     {
193         DOUBLE dbY = Mean(m_Y);
194         UINT32 length = m_X.size() < m_Y.size() ? m_X.size() : m_Y.size();
195         if (isSaveFitY) {
196             m_FitY.clear(); 
197             m_FitY.reserve(length);
198         }
199         for (UINT32 i = 0; i < length; ++i) {
200             DOUBLE yi = GetY(m_X[i]);
201             m_ssr += (yi - dbY) * (yi - dbY);
202             m_sse += (yi - m_Y[i]) * (yi - m_Y[i]);
203             if (isSaveFitY) {
204                 m_FitY.push_back(yi);
205             }
206         }
207         m_rmse = sqrt(m_sse / length);
208     }
209     //计算向量元素均值
210     DOUBLE Mean(const vector<DOUBLE>& v)
211     {
212         DOUBLE total(0);
213         UINT32 length = v.size();
214         for (UINT32 i = 0; i < v.size(); ++i) {
215             total += v[i];
216         }
217         return total / length;
218     }
219 
220 private:
221     vector<DOUBLE> m_X;//待拟合样本数据
222     vector<DOUBLE> m_Y;//待拟合样本数据
223     vector<DOUBLE> m_FitY;//最小二乘拟合后的Y值
224     vector<DOUBLE> m_coef;//最小二乘拟合多项式系数 y=c_0 + c_1*x^1 + c_2*x^2 + ... + c_n*x^n
225     INT32 m_dim;//最小二乘拟合多项式阶数
226     DOUBLE m_ssr;//回归平方和
227     DOUBLE m_sse;//剩余平方和
228     DOUBLE m_rmse;//RMSE均方根误差
229 };
 1 //demo.cpp
 2 #include <iostream>
 3 #include "LeastSquaresFitter.h"
 4 using namespace std;
 5 
 6 int main()
 7 {
 8     double years[] = {1970, 1980, 1990, 2000};
 9     double data[4] = {   12,   11,   14,   13 };
10     vector<DOUBLE> x(years, years + 4);
11     vector<DOUBLE> y(data, data + 4);
12 
13     CLeastSquaresFitter fiter(x, y);
14     fiter.LinearFit();
15     vector<DOUBLE> c = fiter.GetCoef();
16     for (vector<DOUBLE>::iterator it = c.begin(); it != c.end(); ++it)
17         cout << *it << " " ;
18     cout << endl;
19     fiter.PolyFit();
20     c = fiter.GetCoef();
21     for (vector<DOUBLE>::iterator it = c.begin(); it != c.end(); ++it)
22         cout << *it << " " ;
23     cout << endl;
24 
25     return 0;
26 }

 

posted on 2014-11-22 02:50  來時的路  阅读(331)  评论(0)  编辑  收藏  举报