[ceres-solver] AutoDiff
本文的目的是解析 ceres-solver AutoDiff 的实现,说明它是一种类似于 matlab 符号运算的方法。
ceres-solver 使用 ceres::CostFunction
作为计算误差与雅克比的结构。ceres::CostFunction 是一个纯虚类,用户代码继承这个类,并通过实现其纯虚方法 bool Evaluate(double const* const* parameters, double* residuals, double** jacobians);
提供使用待优化参数块(parameters)计算误差(residuals)与雅克比(jacobians) 的方法。对于需要快速验证想法的用户,计算雅克比是繁琐的。
ceres 提供了两种自动计算雅克比的方法——AutoDiff 与 NumericDiff,用户可以分别继承 ceres::AutoDiffCostFunction
、 ceres::NumericDiffCostFunction
以使用这两种方法。选择使用这两种方法之后,用户代码仅需告知 ceres 如何使用 parameters 计算 residuals,至于 jacobians 如何计算,ceres 自行寻找方法。
ceres 的 AutoDiff 使用 Dual Number 计算雅克比。所谓 Dual Number 就是将一个实数写成其自身(为方便将其称为“大量”)与小量(e)的和,并且定义 \(e^2 = 0\) (在计算一阶导数时这么定义)。ceres 实现的 Dual Number 的结构是 ceres::Jet
,Jet 结构中的大量是 T a;
,小量是 Eigen::Matrix<T, N, 1> v;
(此处小量使用一个 Eigen::Vector 表达是介于多元函数对多个变量求导的考虑,后面会解释)。
文件 jet.h
中有一些注释解释 Dual Number 是如何计算导数的。现在摘抄一个注释中的例子如下。
// To handle derivatives of functions taking multiple arguments, different
// infinitesimals are used, one for each variable to take the derivative of. For
// example, consider a scalar function of two scalar parameters x and y:
//
// f(x, y) = x^2 + x * y
//
// Following the technique above, to compute the derivatives df/dx and df/dy for
// f(1, 3) involves doing two evaluations of f, the first time replacing x with
// x + e, the second time replacing y with y + e.
//
// For df/dx:
//
// f(1 + e, y) = (1 + e)^2 + (1 + e) * 3
// = 1 + 2 * e + 3 + 3 * e
// = 4 + 5 * e
//
// --> df/dx = 5
//
// For df/dy:
//
// f(1, 3 + e) = 1^2 + 1 * (3 + e)
// = 1 + 3 + e
// = 4 + e
//
// --> df/dy = 1
//
求函数 f(x, y) = x^2 + x * y
在 (1, 3) 上现在我用微积分的数学方式计算导数。
以上注释说明了使用 Dual Number 计算函数 \(f(x,y)=x^2+xy\) 在 \((1, 3)\) 处对 \(x, y\) 的导数的过程。对 x 的偏导,是将 x 用 Dual Number 1 + e 表示,将 y 用实数 3 表示,代入函数式计算,最终得到的 e 的一次项系数就是函数在 (1, 3) 上对 x 的偏导。实际上这是 L'Hospital 法则的计算机实现。现在使用在微积分中学到的方法计算导数。
注释:(*) 使用一次 L'Hospital 法则,即分子分母分别对 \(\Delta x\) 求一次导数。
分析:在求导数的时候分母一般为 1 次项,即 \(\Delta x\)。使用 L'Hospital 法则,对分母求导,会将 \(\Delta x\) 0 次的项求导消失;而刚好是 \(\Delta x\) 1 次的项求导后是常数;\(\Delta x\) 高于 1 次的项在求导后还会留下 \(\Delta x\),在求极限之后会消失。所以,导数是 1 次项对应的系数,在程序实现中是 e 对应的系数。(但是此处我还没有考虑 \(\Delta x\) 0 次以下的项,现在搞不定。)同理,求二次导数,就是取 \(\Delta x^2\) 的系数。
紧接着下面的注释给出了小量为何使用 Eigen::Vector 表示
的解释。
// To take the gradient of f with the implementation of dual numbers ("jets") in
// this file, it is necessary to create a single jet type which has components
// for the derivative in x and y, and passing them to a templated version of f:
//
// template<typename T>
// T f(const T &x, const T &y) {
// return x * x + x * y;
// }
//
// // The "2" means there should be 2 dual number components.
// // It computes the partial derivative at x=10, y=20.
// Jet<double, 2> x(10, 0); // Pick the 0th dual number for x.
// Jet<double, 2> y(20, 1); // Pick the 1st dual number for y.
// Jet<double, 2> z = f(x, y);
//
// LOG(INFO) << "df/dx = " << z.v[0]
// << "df/dy = " << z.v[1];
//
如果想直接求对所有变量的导数,那么 Dual Number 的 e 的个数就要增加了,有两个变量,就需要 2 个 e,使用 ceres::Jet<double, 2>
。实验验证,可以在使用 AutoDiff 时,于用户代码实现的模板函数 operator() 中故意使模板特例化错误,检查 typename 是否特例化为 Jet<double, [N]>,N 是待优化参数的个数(注意,是 parameters 的个数,不是 parameter blocks 的个数)。
Jet 作为 Dual Number 实现求导具体是要实现求导的一般法则与一些基本函数的导数公式。这些相关的内容可以参考 WikiPedia Differentiation rules。现在对一般法则与基本函数的导数分别举一个在文件 jet.h
中找得到的例子。
“一般法则”举例乘法法则。在 C++ 中基本运算的 operator 仅有 +, -, *, / 四种,仅需对这四种运算实现对应的 Jet operator 即可。在 Python 中有幂运算符 **
,大概 Python 实现还需要考虑这个吧。
乘法法则在数学中可以表达如下。
在 Jet 中对应的 operator*
如下。
template <typename T, int N>
inline Jet<T, N> operator*(const Jet<T, N>& f, const Jet<T, N>& g) {
return Jet<T, N>(f.a * g.a, f.a * g.v + f.v * g.a);
}
“基本函数的导数”举例正弦函数。
正弦函数的导数在数学中表达如下。
在 Jet 中对应的函数 sin
如下。
template <typename T, int N>
inline Jet<T, N> sin(const Jet<T, N>& f) {
return Jet<T, N>(sin(f.a), cos(f.a) * f.v);
}
以上两个例子,代码中形成的 Jet 的小量,对应于数学公式的导数。
另,需要注意在 ceres 中使用 AutoDiff,模板函数 operator() 计算 residuals 过程中使用到的基础函数需要从 ceres 中获得,即不可直接使用 std::sin
函数,应使用 ceres::sin
,以上 sin
函数体内使用到的 sin
, cos
函数是 std::sin
, std::cos
。即 stl 中的模板是无法实例化 ceres::Jet 的。
对于变量负数次幂的处理可以参考代码 operator/(T s, const Jet<T, N>& g)
,即“Scalar 除以 Jet”。
综上所述,ceres-solver 使用 ceres::Jet,实现了 AutoDiff。具体的实现,是通过 ceres::Jet 丰富的 operator 与定义的一系列基本函数(的导数)。