A Gradient Descent Example for Fitting a Gaussian Function
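The Gaussian function is another commonly used function. Taking the logarithm turns fitting it into a linear least-squares problem (the logarithm of a Gaussian is a quadratic polynomial in $x$), which ordinary least squares can solve. That fit, however, minimizes the residuals of $\ln y$ rather than those of $y$ itself, so for better accuracy its result can serve as an initial solution that gradient descent then refines. The following derives the gradient descent formulas for the Gaussian function. The Gaussian function has the form: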
$${y=a \cdot e^{- \frac{ \left ( x - b \right ) ^{2} } {2c ^{2}} } }$$
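Writing the amplitude as $e^{\ln a}$ and expanding the square in the exponent gives its general form (the symbols $a, b, c$ are reused for the new coefficients):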
$${y=e^{ax^{2}+bx+c},\left(a<0\right)}$$
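To see how the two forms correspond, take the logarithm of the Gaussian and expand the square. In this derivation only, the Gaussian's parameters are written $a_0, b_0, c_0$ to keep them apart from the general-form coefficients $a, b, c$, and $a_0>0$ is assumed so that the logarithm exists:

$${ \begin{aligned} \ln y &= \ln a_0 - \frac{\left(x-b_0\right)^2}{2c_0^2} = -\frac{1}{2c_0^2}x^2 + \frac{b_0}{c_0^2}x + \left(\ln a_0 - \frac{b_0^2}{2c_0^2}\right) \\ a &= -\frac{1}{2c_0^2},\qquad b = \frac{b_0}{c_0^2},\qquad c = \ln a_0 - \frac{b_0^2}{2c_0^2} \\ c_0 &= \sqrt{-\frac{1}{2a}},\qquad b_0 = -\frac{b}{2a},\qquad a_0 = e^{\,c-\frac{b^2}{4a}} \end{aligned} }$$

The last line inverts the mapping; it needs $a<0$, which is exactly the constraint attached to the general form. Because $\ln y$ is linear in $a, b, c$, the log-domain fit mentioned at the start is an ordinary linear least-squares problem.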
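Define the error function $e$ (not to be confused with the base of the exponential) as the sum of squared residuals over the sample points $(x_i, y_i)$: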
$${e=\sum_{i}^{}\left ( e^{ax_{i}^{2}+bx_{i}+c}-y_{i} \right )^{2},\left(a<0\right) }$$
Taking the partial derivative with respect to each of ${a,b,c}$ gives:
$${ \begin{aligned} \frac{\partial e}{\partial a}&=\sum_{i}^{}2 f_{i}\left ( a,b,c \right )g_{i}\left ( a,b,c \right )x_{i}^{2} \\ \frac{\partial e}{\partial b}&=\sum_{i}^{}2 f_{i}\left ( a,b,c \right )g_{i}\left ( a,b,c \right )x_{i} \\ \frac{\partial e}{\partial c}&=\sum_{i}^{}2 f_{i}\left ( a,b,c \right )g_{i}\left ( a,b,c \right ) \end{aligned} }$$
where:
$${ \begin{aligned} g_{i}\left ( a,b,c \right )&=e^{ax_{i}^{2}+bx_{i}+c} \\ f_{i}\left ( a,b,c \right )&=g_{i}\left ( a,b,c \right )-y_{i} \end{aligned} }$$
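Sign or factor mistakes in these derivatives are easy to make, so a central-difference check is a cheap safeguard. The sketch below, in plain C++ without OpenCV, implements $e$ and the analytic gradient exactly as written above and a numerical gradient to compare against. The names `Sample`, `sqError`, `gradient` and `numericGradient` are hypothetical, and the step `h = 1e-6` is an assumed value that suits double precision for this problem size.

```cpp
#include <array>
#include <cmath>
#include <vector>

// Hypothetical sample type: one data point (x_i, y_i).
struct Sample {
    double x, y;
};

// e(a, b, c) = sum_i (exp(a*x_i^2 + b*x_i + c) - y_i)^2
double sqError(const std::vector<Sample>& pts, double a, double b, double c) {
    double e = 0.0;
    for (const auto& p : pts) {
        double f = std::exp(a * p.x * p.x + b * p.x + c) - p.y;  // f_i
        e += f * f;
    }
    return e;
}

// Analytic gradient: (de/da, de/db, de/dc) = sum_i 2 f_i g_i (x_i^2, x_i, 1)
std::array<double, 3> gradient(const std::vector<Sample>& pts, double a, double b, double c) {
    std::array<double, 3> grad{0.0, 0.0, 0.0};
    for (const auto& p : pts) {
        double g = std::exp(a * p.x * p.x + b * p.x + c);  // g_i
        double f = g - p.y;                                 // f_i
        grad[0] += 2.0 * f * g * p.x * p.x;
        grad[1] += 2.0 * f * g * p.x;
        grad[2] += 2.0 * f * g;
    }
    return grad;
}

// Central-difference approximation of the same gradient, used only as a check.
std::array<double, 3> numericGradient(const std::vector<Sample>& pts, double a, double b,
                                      double c, double h = 1e-6) {
    return {
        (sqError(pts, a + h, b, c) - sqError(pts, a - h, b, c)) / (2.0 * h),
        (sqError(pts, a, b + h, c) - sqError(pts, a, b - h, c)) / (2.0 * h),
        (sqError(pts, a, b, c + h) - sqError(pts, a, b, c - h)) / (2.0 * h),
    };
}
```

Calling `gradient` and `numericGradient` at any parameter point and comparing the three components should show agreement to several decimal places; at an exact fit, such as $(\ln 0.125, 0, 0)$ for the points $(0, 1)$ and $(\pm 1, 0.125)$, the analytic gradient is zero up to rounding.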
The final iteration formula, where the constant $C>0$ is the step size (learning rate), is:
$${ \left(a,b,c \right)_{n+1} = \left(a,b,c \right)_{n}- \left( \frac{\partial e}{\partial a},\frac{\partial e}{\partial b},\frac{\partial e}{\partial c} \right) \cdot C }$$
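The update rule above can be sketched as a plain fixed-step loop in standard C++. This is a minimal sketch, not the author's implementation: `Sample`, `Params`, `model`, `sqError`, `gradient` and `gradientDescent` are hypothetical names, and it uses a constant $C$ with no step clamping, so $C$ has to be small enough for the chosen starting point.

```cpp
#include <array>
#include <cmath>
#include <vector>

// Hypothetical types: a data point and the parameter vector (a, b, c).
struct Sample {
    double x, y;
};
using Params = std::array<double, 3>;

// g_i = exp(a*x_i^2 + b*x_i + c)
double model(const Params& p, double x) {
    return std::exp(p[0] * x * x + p[1] * x + p[2]);
}

// e = sum_i (g_i - y_i)^2
double sqError(const std::vector<Sample>& pts, const Params& p) {
    double e = 0.0;
    for (const auto& s : pts) {
        double f = model(p, s.x) - s.y;
        e += f * f;
    }
    return e;
}

// (de/da, de/db, de/dc) = sum_i 2 f_i g_i (x_i^2, x_i, 1)
Params gradient(const std::vector<Sample>& pts, const Params& p) {
    Params grad{0.0, 0.0, 0.0};
    for (const auto& s : pts) {
        double g = model(p, s.x);
        double f = g - s.y;
        grad[0] += 2.0 * f * g * s.x * s.x;
        grad[1] += 2.0 * f * g * s.x;
        grad[2] += 2.0 * f * g;
    }
    return grad;
}

// (a,b,c)_{n+1} = (a,b,c)_n - grad(e) * C, with a fixed step C and no clamping.
Params gradientDescent(const std::vector<Sample>& pts, Params p, double C, int iterations) {
    for (int n = 0; n < iterations; ++n) {
        Params grad = gradient(pts, p);
        for (int k = 0; k < 3; ++k) {
            p[k] -= C * grad[k];
        }
    }
    return p;
}
```

For example, starting from $(-1.5, 0.2, 0.3)$ on the points $(0, 1)$ and $(\pm 1, 0.125)$ with $C = 0.1$ and 5000 iterations, the result should approach $(\ln 0.125, 0, 0)$. A larger $C$, or a start with a large $c$ such as $(-2, 0.5, 2)$, makes the exponential terms in the gradient large enough to overshoot, which is the instability discussed next.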
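Below is sample code based on VS2017, Qt 5.9 and OpenCV 4.3.0. It caps both the maximum and the minimum step size to speed up convergence, and it has been tested. Tuning this algorithm shows that fitting an exponential function is harder than fitting an ordinary polynomial: because every gradient term contains the factor $g_i=e^{ax_i^2+bx_i+c}$, the gradient grows or shrinks exponentially with the parameters, so it easily explodes or vanishes and no solution is reached. Fitting this kind of function therefore requires a suitable initial solution to make the optimization work.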
    #include <QDebug>
    #include <opencv2/core.hpp>
    #include <algorithm>
    #include <cmath>
    #include <vector>

    int main()
    {
        // Samples of y = exp(a*x^2 + b*x + c)
        std::vector<cv::Point2f> points{ { 0.0f, 1.0f }, { -1.0f, 0.125f }, { 1.0f, 0.125f } };
        cv::Matx13f resolve(-2.0f, 0.5f, 2.0f);  // initial (a, b, c)
        const float n = static_cast<float>(points.size());

        for (int loop = 0; loop < 200; loop++) {
            // Gradient of the error function, accumulated over all samples
            float dea = 0, deb = 0, dec = 0;
            for (const auto& pt : points) {
                float g = std::exp(resolve(0) * pt.x * pt.x + resolve(1) * pt.x + resolve(2));
                float f = g - pt.y;
                dea += 2 * f * g * pt.x * pt.x;
                deb += 2 * f * g * pt.x;
                dec += 2 * f * g;
            }
            // Average over the samples and apply the step size (C = 0.1 / N)
            dea *= 0.1f / n;
            deb *= 0.1f / n;
            dec *= 0.1f / n;

            // Rescale so the largest step component stays within [0.005, 2]
            float maxv = std::max({ std::fabs(dea), std::fabs(deb), std::fabs(dec) });
            if (maxv > 2) {
                dea = dea / maxv * 2;
                deb = deb / maxv * 2;
                dec = dec / maxv * 2;
            } else if (maxv > 0 && maxv < 0.005f) {  // maxv > 0 avoids division by zero
                dea = dea / maxv * 0.005f;
                deb = deb / maxv * 0.005f;
                dec = dec / maxv * 0.005f;
            }
            resolve(0) -= dea;
            resolve(1) -= deb;
            resolve(2) -= dec;
            qDebug() << dea << deb << dec;
        }
        qDebug() << resolve(0) << resolve(1) << resolve(2) << "-<";
        return 0;
    }
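Both the opening paragraph and the tuning notes recommend starting from a good initial solution. A minimal sketch of the log-domain least-squares fit that can supply one follows, in dependency-free C++ rather than OpenCV, together with the conversion from the general form back to the Gaussian's amplitude, center and width. The names `logLeastSquares`, `toGaussian`, `det3` and `Gaussian` are hypothetical; solving the 3×3 normal equations with Cramer's rule is a choice made here for brevity, not the author's method, and the singularity threshold is an assumed absolute value.

```cpp
#include <array>
#include <cmath>
#include <stdexcept>
#include <vector>

struct Sample {
    double x, y;
};

// Determinant of a 3x3 matrix.
double det3(const double m[3][3]) {
    return m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
         - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
         + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
}

// Fits ln(y) = a*x^2 + b*x + c by ordinary least squares. Needs y_i > 0
// and at least 3 distinct x_i. Returns (a, b, c).
std::array<double, 3> logLeastSquares(const std::vector<Sample>& pts) {
    double s[5] = {0, 0, 0, 0, 0};  // s[k] = sum x^k
    double t[3] = {0, 0, 0};        // t[k] = sum x^k * ln(y)
    for (const auto& p : pts) {
        if (p.y <= 0) throw std::invalid_argument("log fit needs y > 0");
        double ly = std::log(p.y), xk = 1.0;
        for (int k = 0; k < 5; ++k) {
            s[k] += xk;
            if (k < 3) t[k] += xk * ly;
            xk *= p.x;
        }
    }
    // Normal equations M * (a, b, c)^T = r, solved with Cramer's rule.
    double M[3][3] = {{s[4], s[3], s[2]}, {s[3], s[2], s[1]}, {s[2], s[1], s[0]}};
    double r[3] = {t[2], t[1], t[0]};
    double d = det3(M);
    if (std::fabs(d) < 1e-12) throw std::runtime_error("singular normal equations");
    std::array<double, 3> sol{};
    for (int j = 0; j < 3; ++j) {
        double Mj[3][3];
        for (int i = 0; i < 3; ++i)
            for (int k = 0; k < 3; ++k) Mj[i][k] = (k == j) ? r[i] : M[i][k];
        sol[j] = det3(Mj) / d;
    }
    return sol;
}

// Hypothetical result type for y = amplitude * exp(-(x - mean)^2 / (2 sigma^2)).
struct Gaussian {
    double amplitude, mean, sigma;
};

// Converts the general form y = exp(a*x^2 + b*x + c), a < 0, back to a Gaussian.
Gaussian toGaussian(const std::array<double, 3>& p) {
    double a = p[0], b = p[1], c = p[2];
    if (a >= 0) throw std::domain_error("a must be negative");
    return {std::exp(c - b * b / (4.0 * a)), -b / (2.0 * a), std::sqrt(-1.0 / (2.0 * a))};
}
```

For noisy data the output of `logLeastSquares` could replace the hand-picked `resolve(-2.0f, 0.5f, 2.0f)` in the code above; for the three exact sample points used there it already lands on $(\ln 0.125, 0, 0)$ up to rounding, so gradient descent has little left to refine.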