列文伯格-马夸尔特拟合算法(Levenberg Marquardt Fitting)的C#实现

前段时间要用到这个拟合算法,于是在网上搜了下发现了一段matlab代码,写的非常详细,马上改写之。假设我要拟合的函数为a * atan(b * x + c) + d,代码如下:

  1 using Accord.Math;
2 public class LevenbergMarquardtFitting
3 {
4 /// <summary>
5 /// LM拟合,函数形式:a * Atan(b * x + c) + d;
6 /// </summary>
7 /// <param name="data">自变量x</param>
8 /// <param name="obs">观测值</param>
9 /// <param name="a_est">估计的a</param>
10 /// <param name="b_est">估计的b</param>
11 /// <param name="c_est">估计的c</param>
12 /// <param name="d_est">估计的d</param>
13 /// <returns>相似度</returns>
14 public static double Regression(double[] data, double[] obs, out double a_est, out double b_est, out double c_est, out double d_est)
15 {
16 double r = -1;
17 if (data.Length != obs.Length)
18 {
19 throw new Exception("Observation Data and Correspaned Data are not the same dimention!");
20 }
21 // 数据个数
22 int Ndata = obs.Length;
23
24 // 参数维数
25 int Nparams = 4;
26 // 迭代最大次数
27 int n_iters = 1000;
28 // LM算法的阻尼系数初值
29 double lamda = 0.01;
30
31 // 定义雅克比矩阵
32 double[,] J = new double[Ndata,Nparams];
33
34 // 定义海塞矩阵
35 double[,] H = new double[Nparams, Nparams];
36 double[,] H_lm = new double[Nparams, Nparams];
37 double[,] invH_lm = new double[Nparams, Nparams];
38
39 // 定义误差向量
40 double[] ev = new double[Ndata];
41 double[] ev_lm = new double[Ndata];
42
43 // 定义中间变量
44 double[,] g = new double[Nparams, 1];
45 double[] dp = new double[Nparams];
46
47 // 初始猜测s
48 double a0=0.434, b0=1.24e-02, c0=0.942, d0=0.685;
49
50 double[] y_init = new double[Ndata];
51 double[] y_est = new double[Ndata];
52 double y_est_lm;
53
54 for (int i = 0; i < Ndata; i++)
55 y_init[i] = function(a0, b0, c0, d0, data[i]);
56
57 // step1: 变量赋值
58 int updateJ=1;
59 a_est=a0;
60 b_est=b0;
61 c_est=c0;
62 d_est=d0;
63 double linX, e = 0;
64 double a_lm, b_lm, c_lm, d_lm, e_lm;
65
66 // step2: 迭代
67 for (int it = 0; it < n_iters; it++)
68 {
69 if (updateJ == 1)
70 {
71 //根据当前估计值,计算雅克比矩阵
72 for (int i = 0; i < data.Length; i++)
73 {
74 linX = b_est * data[i] + c_est;
75 J[i, 0] = Math.Atan(linX);
76 J[i, 1] = a_est * data[i] / (linX * linX + 1);
77 J[i, 2] = a_est / (linX * linX + 1);
78 J[i, 3] = 1;
79 // 根据当前参数,得到函数值
80 y_est[i] = function(a_est, b_est, c_est, d_est, data[i]);
81 // 计算误差
82 ev[i] = obs[i] - y_est[i];
83 }
84 // 计算(拟)海塞矩阵
85 H = Matrix.TransposeAndMultiply(J, J);
86
87 // 若是第一次迭代,计算误差
88 if (it == 0)
89 e = Matrix.InnerProduct(ev, ev);
90 }
91 // 根据阻尼系数lamda混合得到H矩阵
92 Array.Copy(H, H_lm, H.Length);
93 for (int ii = 0; ii < Nparams; ii++)
94 H_lm[ii, ii] = H[ii, ii] + lamda;
95 // 计算步长dp,并根据步长计算新的可能的\参数估计值
96 double det = H_lm.Determinant();
97 if (det < 1.0e-10)
98 {
99 break;
100 }
101 else
102 {
103 invH_lm = H_lm.Inverse(false);
104 }
105
106 J.TransposeAndMultiply(ev, dp);
107 dp = invH_lm.Multiply(dp);
108
109 a_lm = a_est + dp[0];
110 b_lm = b_est + dp[1];
111 c_lm = c_est + dp[2];
112 d_lm = d_est + dp[3];
113
114 // 计算新的可能估计值对应的y和计算残差e
115 for (int i = 0; i < data.Length; i++)
116 {
117 y_est_lm = function(a_lm, b_lm, c_lm, d_lm, data[i]);
118 ev_lm[i] = obs[i] - y_est_lm;
119 }
120 e_lm = ev_lm.InnerProduct(ev_lm);
121 //根据误差,决定如何更新参数和阻尼系数
122 if (e_lm < e)
123 {
124 lamda /= 10.0f;
125 a_est = a_lm;
126 b_est = b_lm;
127 c_est = c_lm;
128 d_est = d_lm;
129 e = e_lm;
130 updateJ = 1;
131 }
132 else
133 {
134 updateJ = 0;
135 lamda = lamda * 10.0f;
136 }
137 }
138 r = corr_vect(y_est, obs);
139 return r;
140 }
141
142 //定义函数格式
143 private static double function(double a, double b, double c, double d, double x)
144 {
145 return a * Math.Atan(b * x + c) + d;
146 }
147
148 // 计算向量的相关系数
149 private static double corr_vect(double[] x, double[] y)
150 {
151 int n = x.Length;
152 double sumX = 0, sumY = 0, sumX2 = 0, sumY2 = 0, sumXY = 0;
153 for (int i = 0; i < n; i++)
154 {
155 sumX += x[i];
156 sumY += y[i];
157 sumX2 += x[i] * x[i];
158 sumY2 += y[i] * y[i];
159 sumXY += x[i] * y[i];
160 }
161 double lxx = sumX2 - sumX * sumX / n;
162 double lyy = sumY2 - sumY * sumY / n;
163 double lxy = sumXY - sumX * sumY / n;
164
165 return (lxy * lxy / (lxx * lyy));
166 }
167 }

 

,其中用到了机器学习库Accord.NET:http://accord-net.origo.ethz.ch/,主要是偷懒,不想编写矩阵运算方法,用其他的矩阵运算库也可以的。另外,函数应该也能改成其他自定义的。

posted @ 2011-12-08 23:48  zlalex  阅读(4199)  评论(1编辑  收藏  举报