手写高斯牛顿法代码

手写高斯牛顿法代码

目的:曲线拟合
曲线方程:\(y=exp(ax^2+bx+c)+w\)
已知:样本数据x,y
想要得到拟合曲线参数a,b,c
我们的实际小目标:求解增量方程得到ΔX(有了Δx就可以不停的迭代Eabc使得拟合Rabc)


三个步骤:
1、先根据模型生成x,y的真值,并在真值中添加高斯分布的噪声
2、使用高斯牛顿法进行迭代
3、求解高斯牛顿法的增量方程

第一步

首先定义真实的a,b,c参数值为ar,br,cr
这三个也是我们要拟合的最终目标,最终的拟合结果跟这三个值越接近越好
double ar = 1.0, br = 2.0, cr = 1.0; // 真实参数值

然后定义估计的初始参数值ae,be,ce,这三个是最终拟合的结果,最后与真实的a,b,c作比较
double ae = 2.0, be = -1.0, ce = 5.0; // 估计参数值

定义产生的数据个数和加入的噪音

int N = 200;                                 // 数据点
double w_sigma = 1.0;                        // 噪声Sigma值
double inv_sigma = 1.0 / w_sigma;
cv::RNG rng;                                 // OpenCV随机数产生器

根据曲线方程产生数据,x_data就是输入的数据,y_data就是把x带入到方程后得到的数

  vector<double> x_data, y_data;      // 数据
  for (int i = 0; i < N; i++) {
    double x = i / 200.0;
    x_data.push_back(x);
    y_data.push_back(exp(ar * x * x + br * x + cr) + rng.gaussian(w_sigma * w_sigma));
  }

第二步

数据准备好了之后就开始GN迭代
定义误差

double error = yi - exp(ae * xi * xi + be * xi + ce);

然后分别求出每个误差对于状态变量的导数,得出雅可比矩阵

Vector3d J; // 雅可比矩阵
J[0] = -xi * xi * exp(ae * xi * xi + be * xi + ce);  // de/da
J[1] = -xi * exp(ae * xi * xi + be * xi + ce);  // de/db
J[2] = -exp(ae * xi * xi + be * xi + ce);  // de/dc

然后求解H矩阵和偏置b

H += inv_sigma * inv_sigma * J * J.transpose();
b += -inv_sigma * inv_sigma * error * J;

第三步

求解GN的增量方程Hx=b
Vector3d dx = H.ldlt().solve(b);
结果增量dx是一个3x1的矩阵
然后更新ae,be,ce

ae += dx[0];
be += dx[1];
ce += dx[2];

最后更新损失:lastCost = cost;

通过判断损失的大小来决定是否结束迭代:

if (iter > 0 && cost >= lastCost) {
  cout << "cost: " << cost << ">= last cost: " << lastCost << ", break." << endl;
  cout<<"iter times: "<<iter<<endl;
  break;
}


完整代码:

#include <iostream>
#include <chrono>
#include <opencv2/opencv.hpp>
#include <Eigen/Core>
#include <Eigen/Dense>

using namespace std;
using namespace Eigen;

int main(int argc, char **argv) {
  double ar = 1.0, br = 2.0, cr = 1.0;         // 真实参数值
  double ae = 2.0, be = -3.0, ce = 2.0;        // 估计参数值
  int N = 100;                                 // 数据点
  double w_sigma = 1.0;                        // 噪声Sigma值
  double inv_sigma = 1.0 / w_sigma;
  cv::RNG rng;                                 // OpenCV随机数产生器

  vector<double> x_data, y_data;      // 数据
  for (int i = 0; i < N; i++) {
    double x = i / 100.0;
    x_data.push_back(x);
    y_data.push_back(exp(ar * x * x + br * x + cr) + rng.gaussian(w_sigma * w_sigma));
  }

  // 开始Gauss-Newton迭代
  int iterations = 100;    // 迭代次数
  double cost = 0, lastCost = 0;  // 本次迭代的cost和上一次迭代的cost

  chrono::steady_clock::time_point t1 = chrono::steady_clock::now();
  for (int iter = 0; iter < iterations; iter++) {

    Matrix3d H = Matrix3d::Zero();             // Hessian = J^T W^{-1} J in Gauss-Newton
    Vector3d b = Vector3d::Zero();             // bias
    cost = 0;

    for (int i = 0; i < N; i++) {
      double xi = x_data[i], yi = y_data[i];  // 第i个数据点
      double error = yi - exp(ae * xi * xi + be * xi + ce);
      Vector3d J; // 雅可比矩阵
      J[0] = -xi * xi * exp(ae * xi * xi + be * xi + ce);  // de/da
      J[1] = -xi * exp(ae * xi * xi + be * xi + ce);  // de/db
      J[2] = -exp(ae * xi * xi + be * xi + ce);  // de/dc

      H += inv_sigma * inv_sigma * J * J.transpose();
      b += -inv_sigma * inv_sigma * error * J;

      cost += error * error;
    }

    // 求解线性方程 Hx=b
    cout<<"a"<<endl;
    Vector3d dx = H.ldlt().solve(b);
    if (isnan(dx[0])) {
      cout << "result is nan!" << endl;
      break;
    }

    if (iter > 0 && cost >= lastCost) {
      cout << "cost: " << cost << ">= last cost: " << lastCost << ", break." << endl;
      break;
    }

    ae += dx[0];
    be += dx[1];
    ce += dx[2];

    lastCost = cost;

    cout << "total cost: " << cost << ", \t\tupdate: " << dx.transpose() <<
         "\t\testimated params: " << ae << "," << be << "," << ce << endl;
  }

  chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
  chrono::duration<double> time_used = chrono::duration_cast<chrono::duration<double>>(t2 - t1);
  cout << "solve time cost = " << time_used.count() << " seconds. " << endl;

  cout << "estimated abc = " << ae << ", " << be << ", " << ce << endl;
  return 0;
}

posted @ 2022-03-02 15:35  学不会SLAM的  阅读(420)  评论(0编辑  收藏  举报