强大的C# Expression在一个函数求导问题中的简单运用

号称面试的题目总是非常有趣的,这里是又一个例子:

【原题出处

http://topic.csdn.net/u/20110928/15/B00A34FE-8544-42E2-A771-3C4A888DB85A.html

【问题梗概】

求一个函数的一阶导数。

【代码方案】

 

  1 namespace Derivative
  2 {
  3     class Program
  4     {
  5         // 求一个节点表达的算式的导函数  
  6         static Expression GetDerivative(Expression node)
  7         {
  8             if (node.NodeType == ExpressionType.Add
  9                 || node.NodeType == ExpressionType.Subtract)
 10             {   // 该节点在做加减法,套用加减法导数公式  
 11                 BinaryExpression binexp = (BinaryExpression)node;
 12                 Expression dleft = GetDerivative(binexp.Left);
 13                 Expression dright = GetDerivative(binexp.Right);
 14                 BinaryExpression resbinexp;
 15 
 16                 if (node.NodeType == ExpressionType.Add)
 17                     resbinexp = Expression.Add(dleft, dright);
 18                 else
 19                     resbinexp = Expression.Subtract(dleft, dright);
 20                 return resbinexp;
 21             }
 22             else if (node.NodeType == ExpressionType.Multiply)
 23             {   // 该节点在做乘法,套用乘法导数公式  
 24                 BinaryExpression binexp = (BinaryExpression)node;
 25                 Expression left = binexp.Left;
 26                 Expression right = binexp.Right;
 27 
 28                 Expression dleft = GetDerivative(left);
 29                 Expression dright = GetDerivative(right);
 30 
 31                 return Expression.Add(Expression.Multiply(dleft, right),
 32                     Expression.Multiply(left, dright));
 33             }
 34             else if (node.NodeType == ExpressionType.Parameter)
 35             {   // 该节点是x本身(叶子节点),故而其导数即常数1  
 36                 return Expression.Constant(1.0);
 37             }
 38             else if (node.NodeType == ExpressionType.Constant)
 39             {   // 该节点是一个常数(叶子节点),故其导数为零  
 40                 return Expression.Constant(0.0);
 41             }
 42             else if (node.NodeType == ExpressionType.Call)
 43             {
 44                 MethodCallExpression callexp = (MethodCallExpression)node;
 45                 Expression arg0 = callexp.Arguments[0];
 46                 // 一下一元函数求导后均需要乘上自变量的导数
 47                 Expression darg0 = GetDerivative(arg0);
 48                 if (callexp.Method.Name == "Exp")
 49                 {
 50                     // 指数函数的导数还是其本身
 51                     return Expression.Multiply(
 52                            Expression.Call(null, callexp.Method, arg0), darg0);
 53                 }
 54                 else if (callexp.Method.Name == "Sin")
 55                 {
 56                     // 正弦函数的倒数是余弦函数
 57                     MethodInfo miCos = typeof(Math).GetMethod("Cos", 
 58                                        BindingFlags.Public | BindingFlags.Static);
 59                     return Expression.Multiply(
 60                            Expression.Call(null, miCos, arg0), darg0);
 61                 }
 62                 else if (callexp.Method.Name == "Cos")
 63                 {
 64                     // 余弦函数的导数是正弦函数的相反数
 65                     MethodInfo miSin = typeof(Math).GetMethod("Sin", 
 66                                        BindingFlags.Public | BindingFlags.Static);
 67                     return Expression.Multiply(
 68                            Expression.Negate(Expression.Call(null, miSin, arg0)), darg0);
 69                 }
 70             }
 71 
 72             throw new NotImplementedException();    // 其余的尚未实现          
 73         }
 74 
 75         static Func<double, double> GetDerivative(Expression<Func<double, double>> func)
 76         {
 77             // 从Lambda表达式中获得函数体  
 78             Expression resBody = GetDerivative(func.Body);
 79 
 80             // 需要续用Lambda表达式的自变量  
 81             ParameterExpression parX = func.Parameters[0];
 82 
 83             Expression<Func<double, double>> resFunc
 84                 = (Expression<Func<double, double>>)Expression.Lambda(resBody, parX);
 85 
 86             Console.WriteLine("diff function = {0}", resFunc);
 87 
 88             // 编译成CLR的IL表达的函数  
 89             return resFunc.Compile();
 90         }
 91 
 92         static double GetDerivative(Expression<Func<double, double>> func, double x)
 93         {
 94             Func<double, double> diff = GetDerivative(func);
 95             return diff(x);
 96         }
 97 
 98         static void Main(string[] args)
 99         {
100             // 举例:求出函数f(x) = cos(x*x)+sin(3*x)+exp(2*x)在x=2.0处的导数  
101             double y = GetDerivative(x => Math.Cos(x*x) + Math.Sin(3*x) + Math.Exp(2*x), 2.0);
102             Console.WriteLine("f'(x) = {0}", y);
103         }
104     }
105 }

【实现大意】

用表达式分解并递归求导(过程是相当容易的,比想象的还容易)。目前只是实现了一个最简单的模型。

【优势】

给出的是解析解,在求导运算方面没有任何数值解的误差,输出运算也是瞬时的,时间复杂度仅和表达式复杂度相关。

【限制】

1. 函数只能以Lambda表达式输入,只能是能求出解析解的表达式

2. 目前只实现了加减法和乘法

【后续扩展】

1. 实现其他运算符(没有太大难度,只是比较繁琐而已)

2. 表达式树优化(也不太难的,根据情况定),最基本的可以从常数乘法开始……

3. 条件运算符的处理(这个会变得极难极复杂,但一定程度上实现分段函数求导),其他特殊情况(对求导还可以,如果考虑求不定积分问题可能会有很多特殊情况和hardcode)

4. 输入端向字符串解析过渡;复杂运算符->逐渐向自定义的数据结构过渡?……

...

【更新】

20110611: 添加三角和指数函数支持,优化仍未进行。

 

posted @ 2011-10-05 10:27  quanben  阅读(1982)  评论(0编辑  收藏  举报