强大的C# Expression在一个函数求导问题中的简单运用
号称面试的题目总是非常有趣的,这里是又一个例子:
【原题出处】
http://topic.csdn.net/u/20110928/15/B00A34FE-8544-42E2-A771-3C4A888DB85A.html
【问题梗概】
求一个函数的一阶导数。
【代码方案】
1 namespace Derivative 2 { 3 class Program 4 { 5 // 求一个节点表达的算式的导函数 6 static Expression GetDerivative(Expression node) 7 { 8 if (node.NodeType == ExpressionType.Add 9 || node.NodeType == ExpressionType.Subtract) 10 { // 该节点在做加减法,套用加减法导数公式 11 BinaryExpression binexp = (BinaryExpression)node; 12 Expression dleft = GetDerivative(binexp.Left); 13 Expression dright = GetDerivative(binexp.Right); 14 BinaryExpression resbinexp; 15 16 if (node.NodeType == ExpressionType.Add) 17 resbinexp = Expression.Add(dleft, dright); 18 else 19 resbinexp = Expression.Subtract(dleft, dright); 20 return resbinexp; 21 } 22 else if (node.NodeType == ExpressionType.Multiply) 23 { // 该节点在做乘法,套用乘法导数公式 24 BinaryExpression binexp = (BinaryExpression)node; 25 Expression left = binexp.Left; 26 Expression right = binexp.Right; 27 28 Expression dleft = GetDerivative(left); 29 Expression dright = GetDerivative(right); 30 31 return Expression.Add(Expression.Multiply(dleft, right), 32 Expression.Multiply(left, dright)); 33 } 34 else if (node.NodeType == ExpressionType.Parameter) 35 { // 该节点是x本身(叶子节点),故而其导数即常数1 36 return Expression.Constant(1.0); 37 } 38 else if (node.NodeType == ExpressionType.Constant) 39 { // 该节点是一个常数(叶子节点),故其导数为零 40 return Expression.Constant(0.0); 41 } 42 else if (node.NodeType == ExpressionType.Call) 43 { 44 MethodCallExpression callexp = (MethodCallExpression)node; 45 Expression arg0 = callexp.Arguments[0]; 46 // 一下一元函数求导后均需要乘上自变量的导数 47 Expression darg0 = GetDerivative(arg0); 48 if (callexp.Method.Name == "Exp") 49 { 50 // 指数函数的导数还是其本身 51 return Expression.Multiply( 52 Expression.Call(null, callexp.Method, arg0), darg0); 53 } 54 else if (callexp.Method.Name == "Sin") 55 { 56 // 正弦函数的倒数是余弦函数 57 MethodInfo miCos = typeof(Math).GetMethod("Cos", 58 BindingFlags.Public | BindingFlags.Static); 59 return Expression.Multiply( 60 Expression.Call(null, miCos, arg0), darg0); 61 } 62 else if (callexp.Method.Name == "Cos") 63 { 64 // 余弦函数的导数是正弦函数的相反数 65 MethodInfo miSin = typeof(Math).GetMethod("Sin", 66 BindingFlags.Public | BindingFlags.Static); 67 return Expression.Multiply( 68 Expression.Negate(Expression.Call(null, miSin, arg0)), darg0); 69 } 70 } 71 72 throw new NotImplementedException(); // 其余的尚未实现 73 } 74 75 static Func<double, double> GetDerivative(Expression<Func<double, double>> func) 76 { 77 // 从Lambda表达式中获得函数体 78 Expression resBody = GetDerivative(func.Body); 79 80 // 需要续用Lambda表达式的自变量 81 ParameterExpression parX = func.Parameters[0]; 82 83 Expression<Func<double, double>> resFunc 84 = (Expression<Func<double, double>>)Expression.Lambda(resBody, parX); 85 86 Console.WriteLine("diff function = {0}", resFunc); 87 88 // 编译成CLR的IL表达的函数 89 return resFunc.Compile(); 90 } 91 92 static double GetDerivative(Expression<Func<double, double>> func, double x) 93 { 94 Func<double, double> diff = GetDerivative(func); 95 return diff(x); 96 } 97 98 static void Main(string[] args) 99 { 100 // 举例:求出函数f(x) = cos(x*x)+sin(3*x)+exp(2*x)在x=2.0处的导数 101 double y = GetDerivative(x => Math.Cos(x*x) + Math.Sin(3*x) + Math.Exp(2*x), 2.0); 102 Console.WriteLine("f'(x) = {0}", y); 103 } 104 } 105 }
【实现大意】
用表达式分解并递归求导(过程是相当容易的,比想象的还容易)。目前只是实现了一个最简单的模型。
【优势】
给出的是解析解,在求导运算方面没有任何数值解的误差,输出运算也是瞬时的,时间复杂度仅和表达式复杂度相关。
【限制】
1. 函数只能以Lambda表达式输入,只能是能求出解析解的表达式
2. 目前只实现了加减法和乘法
【后续扩展】
1. 实现其他运算符(没有太大难度,只是比较繁琐而已)
2. 表达式树优化(也不太难的,根据情况定),最基本的可以从常数乘法开始……
3. 条件运算符的处理(这个会变得极难极复杂,但一定程度上实现分段函数求导),其他特殊情况(对求导还可以,如果考虑求不定积分问题可能会有很多特殊情况和hardcode)
4. 输入端向字符串解析过渡;复杂运算符->逐渐向自定义的数据结构过渡?……
...
【更新】
20110611: 添加三角和指数函数支持,优化仍未进行。
enjoy every minute of an appless, googless and oracless life