拟合高斯函数的梯度下降法例子

高斯函数也是一种常见的函数。拟合它可以通过求对数转换成线性规划问题,从而用最小二乘法拟合。不过为了精确一点,可以用最小二乘法拟合得到初始解之后再用梯度下降法求精。以下将描述高斯函数的梯度下降法公式推导过程。高斯函数的形式为:

$${y=a \cdot e^{- \frac{ \left ( x - b \right ) ^{2} } {2c ^{2}} } }$$

它的一般式为:

$${y=e^{ax^{2}+bx+c},\left(a<0\right)}$$

定义误差函数如下:

$${e=\sum_{i}^{}\left ( e^{ax_{i}^{2}+bx_{i}+c}-y_{i} \right )^{2},\left(a<0\right) }$$

分别对${a,b,c}$求偏导数得:

$${ \begin{align} \frac{\partial e}{\partial a}&=\sum_{i}^{}2 f_{i}\left ( a,b,c \right )g_{i}\left ( a,b,c \right )x_{i}^{2} \\ \frac{\partial e}{\partial b}&=\sum_{i}^{}2 f_{i}\left ( a,b,c \right )g_{i}\left ( a,b,c \right )x_{i} \\ \frac{\partial e}{\partial c}&=\sum_{i}^{}2 f_{i}\left ( a,b,c \right )g_{i}\left ( a,b,c \right ) \end{align} }$$

其中:

$${ \begin{align*} g_{i}\left ( a,b,c \right )&=e^{ax_{i}^{2}+bx_{i}+c} \\ f_{i}\left ( a,b,c \right )&=g_{i}\left ( a,b,c \right )-y_{i} \end{align*} }$$

最终迭代公式为:

$${ \left(a,b,c \right)_{n+1} = \left(a,b,c \right)_{n}- \left( \frac{\partial e}{\partial a},\frac{\partial e}{\partial b},\frac{\partial e}{\partial c} \right) \cdot C }$$

下面给出基于VS2017、Qt5.9和OpenCV430的示例代码,代码中限制了最大步长和最小步长以加快收敛,已通过验证。在此算法调参的过程中可以发现:拟合指数函数调参是比一般的多项式函数更难一点,容易出现梯度爆炸梯度消失的情况,导致无法求解。所以拟合这类函数需要有一组合适的初始解,以便于优化。

void main()
{
    vector<Point2f> points;
    points.push_back({ 0, 1 });
    points.push_back({ -1, 0.125f });
    points.push_back({ 1, 0.125f });
    Matx13f resolve(-2.0f, 0.5f, 2.0f);
    for (int loop = 0; loop < 200; loop++)
    {
        float dea = 0, deb = 0, dec = 0;
        for (const auto& pt : points)
        {
            float g = expf(resolve(0) * pt.x * pt.x + resolve(1) * pt.x + resolve(2));
            float f = g - pt.y;
            dea += 2 * f * g * pt.x * pt.x;
            deb += 2 * f * g * pt.x;
            dec += 2 * f * g;
        }
        dea /= points.size();
        deb /= points.size();
        dec /= points.size();
        dea *= 0.1f;
        deb *= 0.1f;
        dec *= 0.1f;
        float maxv = std::max({ fabs(dea), fabs(deb), fabs(dec) });
        if (maxv > 2)
        {
            dea = dea / maxv * 2;
            deb = deb / maxv * 2;
            dec = dec / maxv * 2;
        }
        else if (maxv < 0.005f)
        {
            dea = dea / maxv * 0.005f;
            deb = deb / maxv * 0.005f;
            dec = dec / maxv * 0.005f;
        }
        resolve(0) -= dea;
        resolve(1) -= deb;
        resolve(2) -= dec;
        qDebug() << dea << deb << dec;
    }
    qDebug() << resolve(0) << resolve(1) << resolve(2) << "-<";
}

 

posted @ 2023-04-10 14:17  兜尼完  阅读(183)  评论(0编辑  收藏  举报