C# Does Tail Recursion Too

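Last night I read 装配脑袋's article 《VS2008亮点:用Lambda表达式进行函数式编程》 ("VS2008 Highlights: Functional Programming with Lambda Expressions"). It shows one way to make lambda expressions recursive, and the comments pointed out that a special delegate type achieves a similar effect in C#. I found it inspiring, and it reminded me that Python can fake tail recursion by throwing exceptions, so I tried building a tail-recursion wrapper for C# too. Purely for fun :)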

A lambda expression can represent a function, e.g. Func<int,int> func = n => n + 1;

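Because C# cannot infer a lambda's type with var, and even C# 4 cannot declare a lambda as dynamic, the type declarations get very long once a lambda's inputs and outputs are themselves Func<...> :( VB handles this better, though VB's Function syntax looks rather verbose too :( (This describes the C# of the time: C# 10 later gave lambdas a natural type, so var f = (int n) => n + 1; now compiles.)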

To have a lambda expression call itself recursively, you would want to write something like:

Func<int,int> func = n => n <= 1 ? 1 : n * func(n - 1);

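Of course, this statement does not compile: the lambda is anonymous, and at the point where its body refers to func, func has not yet been definitely assigned (error CS0165, "Use of unassigned local variable 'func'"). Hence the following workarounds.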

The first one is, in my opinion, the simplest and least effort:

Func<int,int> func = null;

func = n => n <= 1 ? 1 : n * func(n - 1);
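One consequence of this trick is worth knowing: the lambda captures the variable func, not the delegate it currently holds, so if func is later assigned something else, the "recursive" call silently goes to the new delegate. Python names are looked up late in the same way, so here is a small Python analogue (an illustration, not C#) of the effect:

```python
# Python analogue of "Func<int,int> func = null; func = n => ... func(n - 1);":
# the body looks up the name `fact` when it runs, not when the lambda is created.
fact = lambda n: 1 if n <= 1 else n * fact(n - 1)
print(fact(5))      # 120

original = fact     # keep a reference to the recursive function
fact = lambda n: 0  # then rebind the name it uses for its recursive call

print(original(5))  # 0: computes 5 * fact(4), and fact now returns 0
```

Passing the function in as a parameter, as the next approach does, removes this dependence on a mutable variable.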

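The other method takes more work: pull func out and pass it in as a parameter (which the inner lambda then captures in its closure), like this: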

Func<Func<int,int>,Func<int,int>> func = self => n => n <= 1 ? 1 : n * self(n - 1);

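With this method, the caller has to pass the function itself into the function as an argument (quite a tongue twister). func is not recursive on its own: given the function to use for the recursive call, it returns one step of the computation. Here is the code for this approach: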

Implementing recursive calls
// A delegate type that takes an instance of itself, so a lambda can be applied to itself.
private delegate T SelfFunc<T>(SelfFunc<T> self);

public static Func<T, TResult> Recursive<T, TResult>(
    Func<Func<T, TResult>, Func<T, TResult>> func)
{
    // s receives itself as y, so y(y) rebuilds the same curried function without naming it.
    SelfFunc<Func<Func<Func<T, TResult>, Func<T, TResult>>, Func<T, TResult>>> s =
        y => f => p => f(y(y)(f))(p);
    return s(s)(func);
}

public static Func<T1, T2, TResult> Recursive<T1, T2, TResult>(
    Func<Func<T1, T2, TResult>, Func<T1, T2, TResult>> func)
{
    SelfFunc<Func<Func<Func<T1, T2, TResult>, Func<T1, T2, TResult>>, Func<T1, T2, TResult>>> s =
        y => f => (p1, p2) => f(y(y)(f))(p1, p2);
    return s(s)(func);
}

public static Func<T1, T2, T3, TResult> Recursive<T1, T2, T3, TResult>(
    Func<Func<T1, T2, T3, TResult>, Func<T1, T2, T3, TResult>> func)
{
    SelfFunc<Func<Func<Func<T1, T2, T3, TResult>, Func<T1, T2, T3, TResult>>, Func<T1, T2, T3, TResult>>> s =
        y => f => (p1, p2, p3) => f(y(y)(f))(p1, p2, p3);
    return s(s)(func);
}
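The same self-application trick can be sketched in Python, which makes the data flow easier to follow without the long delegate types. This is an illustrative translation, not the code above; Python's *args lets one definition stand in for the three C# overloads, and the names recursive, factorial and triple are just for this sketch.

```python
def recursive(func):
    # Mirrors: SelfFunc<...> s = y => f => p => f(y(y)(f))(p); return s(s)(func);
    # s is handed itself as y, so y(y) rebuilds "the recursive function"
    # without s ever referring to its own name.
    s = lambda y: lambda f: lambda *p: f(y(y)(f))(*p)
    return s(s)(func)


# func is only one step: given `self` (the function to call recursively),
# it returns the body.
factorial = recursive(lambda self: lambda n: 1 if n <= 1 else n * self(n - 1))
print([factorial(n) for n in range(8)])  # [1, 1, 2, 6, 24, 120, 720, 5040]

triple = recursive(lambda self: lambda p, q, r:
                   0 if p <= 0 or q <= 0 or r <= 0 else p * q * r + self(p - 1, q - 2, r - 3))
print(triple(3, 4, 5))  # 68, as in RecursiveTest3
```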

Test code:

[TestMethod()]
public void RecursiveTest1()
{
    Func<int, int> factorial = Functions.Recursive<int, int>(
        self => n => (n <= 1 ? 1 : n * self(n - 1)));
    Assert.AreEqual(1, factorial(0));
    Assert.AreEqual(1, factorial(1));
    Assert.AreEqual(2, factorial(2));
    Assert.AreEqual(6, factorial(3));
    Assert.AreEqual(24, factorial(4));
    Assert.AreEqual(120, factorial(5));
    Assert.AreEqual(720, factorial(6));
    Assert.AreEqual(5040, factorial(7));
}

[TestMethod()]
public void RecursiveTest2()
{
    Func<int, int, int> factorial = Functions.Recursive<int, int, int>(
        self => (p, q) => (p == 0 || q == 0) ? 0 : p * q + self(p + 1, q - 1));
    Assert.AreEqual(0, factorial(0, 1));  //0
    Assert.AreEqual(1, factorial(1, 1));  //1*1
    Assert.AreEqual(4, factorial(1, 2));  //1*2+2*1
    Assert.AreEqual(2, factorial(2, 1));  //2*1
    Assert.AreEqual(7, factorial(2, 2));  //2*2+3*1
    Assert.AreEqual(65, factorial(3, 5)); //3*5+4*4+5*3+6*2+7*1
}

[TestMethod()]
public void RecursiveTest3()
{
    Func<int, int, int, int> factorial = Functions.Recursive<int, int, int, int>(
        self => (p, q, r) => (p <= 0 || q <= 0 || r <= 0) ? 0 : p * q * r + self(p - 1, q - 2, r - 3));
    Assert.AreEqual(0, factorial(0, 1, 2));  //0
    Assert.AreEqual(6, factorial(1, 2, 3));  //1*2*3
    Assert.AreEqual(25, factorial(2, 3, 4)); //2*3*4+1*1*1
    Assert.AreEqual(68, factorial(3, 4, 5)); //3*4*5+2*2*2
}

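After writing this I was a little excited for a while, then suddenly realized how silly I was: why make it so complicated? I can simply recurse inside the wrapper function itself. Hence the simplified version below: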

Simplified version
public static Func<T, TResult> Recursive<T, TResult>(
    Func<Func<T, TResult>, Func<T, TResult>> func)
{
    return p => func(Recursive<T, TResult>(func))(p);
}

public static Func<T1, T2, TResult> Recursive<T1, T2, TResult>(
    Func<Func<T1, T2, TResult>, Func<T1, T2, TResult>> func)
{
    return (p1, p2) => func(Recursive<T1, T2, TResult>(func))(p1, p2);
}

public static Func<T1, T2, T3, TResult> Recursive<T1, T2, T3, TResult>(
    Func<Func<T1, T2, T3, TResult>, Func<T1, T2, T3, TResult>> func)
{
    return (p1, p2, p3) => func(Recursive<T1, T2, T3, TResult>(func))(p1, p2, p3);
}
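For comparison, here is the simplified version in the same Python notation (again an illustrative translation with made-up names): the wrapper refers to itself by name, and because the nested wrapper is built inside the returned lambda, nothing recurses until the function is actually called.

```python
def recursive(func):
    # recursive(func) is evaluated only when the returned function runs,
    # so building the wrapper does not loop forever.
    return lambda *args: func(recursive(func))(*args)


# Same body as RecursiveTest2: p*q + (p+1)*(q-1) + ... until p or q reaches 0.
pairs = recursive(lambda self: lambda p, q:
                  0 if p == 0 or q == 0 else p * q + self(p + 1, q - 1))
print(pairs(1, 2), pairs(2, 2), pairs(3, 5))  # 4 7 65
```

Neither this nor the self-application version removes stack growth: every recursive call still adds frames, which is the problem the tail-recursive wrappers in the rest of the post address.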

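At this point I remembered some godlike code by a certain genius that gives Python tail recursion (source: the ActiveState recipe whose URL appears in the code):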

Tail recursion in Python
## {{{ http://code.activestate.com/recipes/474088/ (r1)
#!/usr/bin/env python2.4
# This program shows off a python decorator
# which implements tail call optimization. It
# does this by throwing an exception if it is
# its own grandparent, and catching such
# exceptions to recall the stack.

import sys

class TailRecurseException:
    def __init__(self, args, kwargs):
        self.args = args
        self.kwargs = kwargs

def tail_call_optimized(g):
    """
    This function decorates a function with tail call
    optimization. It does this by throwing an exception
    if it is its own grandparent, and catching such
    exceptions to fake the tail call optimization.

    This function fails if the decorated
    function recurses in a non-tail context.
    """
    def func(*args, **kwargs):
        f = sys._getframe()
        if f.f_back and f.f_back.f_back \
            and f.f_back.f_back.f_code == f.f_code:
            raise TailRecurseException(args, kwargs)
        else:
            while 1:
                try:
                    return g(*args, **kwargs)
                except TailRecurseException, e:
                    args = e.args
                    kwargs = e.kwargs
    func.__doc__ = g.__doc__
    return func

@tail_call_optimized
def factorial(n, acc=1):
    "calculate a factorial"
    if n == 0:
        return acc
    return factorial(n-1, n*acc)

print factorial(10000)
# prints a big, big number,
# but doesn't hit the recursion limit.

@tail_call_optimized
def fib(i, current = 0, next = 1):
    if i == 0:
        return current
    else:
        return fib(i - 1, next, current + next)

print fib(10000)
# also prints a big number,
# but doesn't hit the recursion limit.
## end of http://code.activestate.com/recipes/474088/ }}}

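Here an exception is used to jump out of the call stack. When the wrapper sees that it has been re-entered from inside g (its grandparent frame is another activation of itself), it raises an exception carrying the new arguments; the outer activation catches it and calls g again in a loop, so the stack never grows. This code is actually quite similar to the second approach above: both take a function as input and return a processed (wrapped) function.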
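The recipe above targets Python 2 (an old-style exception class, `except TailRecurseException, e:` and the `print` statement), so it does not run on Python 3. Below is a minimal Python 3 adaptation of the same grandparent-frame trick. It assumes CPython, since `sys._getframe` is an implementation detail, and it renames the exception's attributes to `call_args`/`call_kwargs` (my choice, not the recipe's) so they do not shadow `BaseException.args`.

```python
import sys


class TailRecurseException(Exception):
    """Carries the arguments of a tail call back to the outer activation."""

    def __init__(self, args, kwargs):
        super().__init__()
        self.call_args = args      # renamed from the recipe's self.args
        self.call_kwargs = kwargs  # renamed from the recipe's self.kwargs


def tail_call_optimized(g):
    def func(*args, **kwargs):
        f = sys._getframe()
        # Grandparent frame is another activation of func: we were called
        # from inside g, so unwind to that activation instead of recursing.
        if f.f_back and f.f_back.f_back and f.f_back.f_back.f_code == f.f_code:
            raise TailRecurseException(args, kwargs)
        while True:
            try:
                return g(*args, **kwargs)
            except TailRecurseException as e:
                args, kwargs = e.call_args, e.call_kwargs

    func.__doc__ = g.__doc__
    return func


@tail_call_optimized
def factorial(n, acc=1):
    "calculate a factorial"
    if n == 0:
        return acc
    return factorial(n - 1, n * acc)


big = factorial(10000)   # 10000 tail calls, far past the default recursion limit
print(big.bit_length())  # size only: recent Python 3 caps str() of huge ints by default
```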

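So with just a small change to self, the "self-call" function passed in from outside, it should be possible to implement tail recursion in C# as well :)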

First, define an exception type that stores the call argument:

private class TailRecursiveException<T> : Exception
{
    public TailRecursiveException(T argument)
    {
        Argument = argument;
    }
    public T Argument { get; private set; }
}

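Next, define a callback function. If it actually made the call, we would be back to ordinary recursion, so instead it throws an exception and stores the call argument in it: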

private static TResult TailCallback<T, TResult>(T p)
{
throw new TailRecursiveException<T>(p);
}

Finally, the wrapper catches the exception and calls the function again in an infinite loop:

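public static Func<T, TResult> TailRecursive<T, TResult>(
    Func<Func<T, TResult>, Func<T, TResult>> func)
{
    return p =>
    {
        T argument = p;
        while (true)
        {
            try
            {
                // TailCallback throws instead of recursing, so a tail call
                // surfaces here as an exception carrying its argument.
                return func(TailCallback<T, TResult>)(argument);
            }
            catch (TailRecursiveException<T> e)
            {
                // Loop again with the new argument instead of growing the stack.
                argument = e.Argument;
            }
        }
    };
}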
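The exception-based wrapper can be sketched in Python the same way. The names TailCall, _tail_callback and tail_recursive are illustrative stand-ins for TailRecursiveException<T>, TailCallback and TailRecursive, and a single *args version replaces the per-arity overloads. The last lines show the restriction it shares with the Python recipe: the callback is only correct in tail position, because the thrown exception discards any pending work, such as the `n *` in a plain factorial.

```python
class TailCall(Exception):
    """Carries the arguments of a tail call (the TailRecursiveException role)."""

    def __init__(self, args):
        super().__init__()
        self.call_args = args


def _tail_callback(*args):
    # Calling for real would recurse; instead, record the arguments and unwind.
    raise TailCall(args)


def tail_recursive(func):
    def run(*args):
        while True:
            try:
                return func(_tail_callback)(*args)
            except TailCall as e:
                args = e.call_args  # retry with the new arguments; stack stays flat
    return run


fib = tail_recursive(lambda self: lambda i, n, a, b:
                     1 if n < 3 else (a + b if i == n else self(i + 1, n, b, a + b)))
print([fib(3, n, 1, 1) for n in (1, 2, 3, 4, 6, 10, 30)])  # [1, 1, 2, 3, 8, 55, 832040]

# Not a tail call: the exception throws away the pending "n *", so this returns 1.
wrong = tail_recursive(lambda self: lambda n: 1 if n <= 1 else n * self(n - 1))
print(wrong(5))  # 1, not 120
```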

Here is the complete code:

Tail recursion in C#
public static Func<T, TResult> TailRecursive<T, TResult>(
    Func<Func<T, TResult>, Func<T, TResult>> func)
{
    return p =>
    {
        T argument = p;
        while (true)
        {
            try
            {
                return func(TailCallback<T, TResult>)(argument);
            }
            catch (TailRecursiveException<T> e)
            {
                argument = e.Argument;
            }
        }
    };
}

public static Func<T1, T2, TResult> TailRecursive<T1, T2, TResult>(
    Func<Func<T1, T2, TResult>, Func<T1, T2, TResult>> func)
{
    return (p1, p2) =>
    {
        T1 argument1 = p1;
        T2 argument2 = p2;
        while (true)
        {
            try
            {
                return func(TailCallback<T1, T2, TResult>)(argument1, argument2);
            }
            catch (TailRecursiveException<T1, T2> e)
            {
                argument1 = e.Argument1;
                argument2 = e.Argument2;
            }
        }
    };
}

public static Func<T1, T2, T3, TResult> TailRecursive<T1, T2, T3, TResult>(
    Func<Func<T1, T2, T3, TResult>, Func<T1, T2, T3, TResult>> func)
{
    return (p1, p2, p3) =>
    {
        var argument1 = p1;
        var argument2 = p2;
        var argument3 = p3;
        while (true)
        {
            try
            {
                return func(TailCallback<T1, T2, T3, TResult>)(argument1, argument2, argument3);
            }
            catch (TailRecursiveException<T1, T2, T3> e)
            {
                argument1 = e.Argument1;
                argument2 = e.Argument2;
                argument3 = e.Argument3;
            }
        }
    };
}

public static Func<T1, T2, T3, T4, TResult> TailRecursive<T1, T2, T3, T4, TResult>(
    Func<Func<T1, T2, T3, T4, TResult>, Func<T1, T2, T3, T4, TResult>> func)
{
    return (p1, p2, p3, p4) =>
    {
        var argument1 = p1;
        var argument2 = p2;
        var argument3 = p3;
        var argument4 = p4;
        while (true)
        {
            try
            {
                return func(TailCallback<T1, T2, T3, T4, TResult>)(argument1, argument2, argument3, argument4);
            }
            catch (TailRecursiveException<T1, T2, T3, T4> e)
            {
                argument1 = e.Argument1;
                argument2 = e.Argument2;
                argument3 = e.Argument3;
                argument4 = e.Argument4;
            }
        }
    };
}

private static TResult TailCallback<T, TResult>(T p)
{
throw new TailRecursiveException<T>(p);
}

private static TResult TailCallback<T1, T2, TResult>(T1 p1, T2 p2)
{
throw new TailRecursiveException<T1, T2>(p1, p2);
}

private static TResult TailCallback<T1, T2, T3, TResult>(T1 p1, T2 p2, T3 p3)
{
throw new TailRecursiveException<T1, T2, T3>(p1, p2, p3);
}

private static TResult TailCallback<T1, T2, T3, T4, TResult>(T1 p1, T2 p2, T3 p3, T4 p4)
{
throw new TailRecursiveException<T1, T2, T3, T4>(p1, p2, p3, p4);
}

private class TailRecursiveException<T> : Exception
{
    public TailRecursiveException(T argument)
    {
        Argument = argument;
    }
    public T Argument { get; private set; }
}

private class TailRecursiveException<T1, T2> : Exception
{
    public TailRecursiveException(T1 p1, T2 p2)
    {
        Argument1 = p1;
        Argument2 = p2;
    }
    public T1 Argument1 { get; private set; }
    public T2 Argument2 { get; private set; }
}

private class TailRecursiveException<T1, T2, T3> : Exception
{
    public TailRecursiveException(T1 p1, T2 p2, T3 p3)
    {
        Argument1 = p1;
        Argument2 = p2;
        Argument3 = p3;
    }
    public T1 Argument1 { get; private set; }
    public T2 Argument2 { get; private set; }
    public T3 Argument3 { get; private set; }
}

private class TailRecursiveException<T1, T2, T3, T4> : Exception
{
    public TailRecursiveException(T1 p1, T2 p2, T3 p3, T4 p4)
    {
        Argument1 = p1;
        Argument2 = p2;
        Argument3 = p3;
        Argument4 = p4;
    }
    public T1 Argument1 { get; private set; }
    public T2 Argument2 { get; private set; }
    public T3 Argument3 { get; private set; }
    public T4 Argument4 { get; private set; }
}

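Since I couldn't think of an effective test for the single-argument version (factorial hits the int limit almost immediately), I tested with the Fibonacci sequence instead, which exercises the four-argument version: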

// i: current index, n: target index, a = Fib(i - 2), b = Fib(i - 1)
var fib = TailRecursive<int, int, int, int, int>(self => (i, n, a, b) =>
{
    Console.WriteLine("=>{0}/{1}:{2},{3}", i, n, a, b); // comment out this line for performance tests
    return (n < 3 ? 1 : (i == n ? a + b : self(i + 1, n, b, a + b)));
});
Action<int> Fib = n =>
{
    Console.WriteLine("Fib({0}) = ?", n);
    var result = fib(3, n, 1, 1);
    Console.WriteLine("Fib({0}) = {1}", n, result);
};
Fib(1);
Fib(2);
Fib(3);
Fib(4);
Fib(6);
Fib(10);
Fib(30);

The output looks correct:

Test output
Fib(1) = ?
=>3/1:1,1
Fib(1) = 1
Fib(2) = ?
=>3/2:1,1
Fib(2) = 1
Fib(3) = ?
=>3/3:1,1
Fib(3) = 2
Fib(4) = ?
=>3/4:1,1
=>4/4:1,2
Fib(4) = 3
Fib(6) = ?
=>3/6:1,1
=>4/6:1,2
=>5/6:2,3
=>6/6:3,5
Fib(6) = 8
Fib(10) = ?
=>3/10:1,1
=>4/10:1,2
=>5/10:2,3
=>6/10:3,5
=>7/10:5,8
=>8/10:8,13
=>9/10:13,21
=>10/10:21,34
Fib(10) = 55
Fib(30) = ?
=>3/30:1,1
=>4/30:1,2
=>5/30:2,3
=>6/30:3,5
=>7/30:5,8
=>8/30:8,13
=>9/30:13,21
=>10/30:21,34
=>11/30:34,55
=>12/30:55,89
=>13/30:89,144
=>14/30:144,233
=>15/30:233,377
=>16/30:377,610
=>17/30:610,987
=>18/30:987,1597
=>19/30:1597,2584
=>20/30:2584,4181
=>21/30:4181,6765
=>22/30:6765,10946
=>23/30:10946,17711
=>24/30:17711,28657
=>25/30:28657,46368
=>26/30:46368,75025
=>27/30:75025,121393
=>28/30:121393,196418
=>29/30:196418,317811
=>30/30:317811,514229
Fib(30) = 832040

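Finally, I also tried Fib(1000000). After waiting several minutes, i had passed 180000 and there was still no stack overflow exception, so I suppose the test passed. (The values themselves silently wrap around once they exceed int.MaxValue, which happens from Fib(47) on, but this run only checks stack depth.)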

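Lao Zhao (老赵) has also described how the MSIL tail. prefix instruction can implement tail calls, which should be far more efficient than throwing exceptions. Personally, though, I find adding an instruction at the IL level rather cumbersome. Does anyone know a simpler, easier-to-use approach?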

================================================================

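After thinking about it for a whole evening, I finally came up with a better way to implement tail recursion. The idea is as follows: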

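1. Define a callback class that stores the arguments of the current callback, plus an IsCallback flag recording whether the user initiated a callback.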

public class RecFunc<T1, T2, T3, T4, TResult>
{
    public RecFunc(T1 p1, T2 p2, T3 p3, T4 p4)
    {
        IsCallback = false;
        Argument1 = p1;
        Argument2 = p2;
        Argument3 = p3;
        Argument4 = p4;
    }
    public bool IsCallback { get; internal set; }
    public T1 Argument1 { get; private set; }
    public T2 Argument2 { get; private set; }
    public T3 Argument3 { get; private set; }
    public T4 Argument4 { get; private set; }

    // Records the arguments of the tail call instead of making it;
    // the returned default(TResult) is only a placeholder.
    public TResult Callback(T1 p1, T2 p2, T3 p3, T4 p4)
    {
        Argument1 = p1;
        Argument2 = p2;
        Argument3 = p3;
        Argument4 = p4;
        IsCallback = true;
        return default(TResult);
    }
}

2. The outer code uses IsCallback to decide whether to keep looping or to return the result.

public static Func<T1, T2, T3, T4, TResult> TailRecursive<T1, T2, T3, T4, TResult>(
    Func<RecFunc<T1, T2, T3, T4, TResult>, T1, T2, T3, T4, TResult> func)
{
    return (p1, p2, p3, p4) =>
    {
        var rec = new RecFunc<T1, T2, T3, T4, TResult>(p1, p2, p3, p4);
        while (true)
        {
            var result = func(rec, rec.Argument1, rec.Argument2, rec.Argument3, rec.Argument4);
            if (rec.IsCallback)
            {
                // The body requested a tail call: its result is only a placeholder,
                // so reset the flag and run again with the stored arguments.
                rec.IsCallback = false;
                continue;
            }
            return result;
        }
    };
}
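The flag-based version also translates directly. In this Python sketch (illustrative names; one class with *args stands in for the generic RecFunc<T1, T2, T3, T4, TResult>), callback only records the next arguments and the loop re-runs the body while the flag is set. The same tail-position rule applies: the placeholder returned by callback must be returned unchanged, otherwise it leaks into the computation.

```python
class RecFunc:
    """Stores the next call's arguments and whether a tail call was requested."""

    def __init__(self, *args):
        self.args = args
        self.is_callback = False

    def callback(self, *args):
        # Record the tail call instead of making it. The return value is a
        # placeholder, playing the role of default(TResult) in C#.
        self.args = args
        self.is_callback = True
        return None


def tail_recursive(func):
    def run(*args):
        rec = RecFunc(*args)
        while True:
            result = func(rec, *rec.args)
            if rec.is_callback:
                rec.is_callback = False  # a tail call was requested: loop again
                continue
            return result
    return run


fib = tail_recursive(lambda rec, i, n, a, b:
                     1 if n < 3 else (a + b if i == n else rec.callback(i + 1, n, b, a + b)))
print([fib(3, n, 1, 1) for n in (1, 2, 3, 4, 6, 10, 30)])  # [1, 1, 2, 3, 8, 55, 832040]
```

No exception object is created per iteration here; the only per-call state is the reused RecFunc instance.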

The test code needs a small change:

var fib = TailRecursive<int, int, int, int, int>((rec, i, n, a, b) =>
{
    Console.WriteLine("=>{0}/{1}:{2},{3}", i, n, a, b); // comment out this line for performance tests
    return (n < 3 ? 1 : (i == n ? a + b : rec.Callback(i + 1, n, b, a + b)));
});
Action<int> Fib = n =>
{
    Console.WriteLine("Fib({0}) = ?", n);
    var result = fib(3, n, 1, 1);
    Console.WriteLine("Fib({0}) = {1}", n, result);
};
Fib(1);
Fib(2);
Fib(3);
Fib(4);
Fib(6);
Fib(10);
Fib(30);
Fib(1000000);

With Console.WriteLine("=>{0}/{1}:{2},{3}", i, n, a, b); commented out in the test code, running Fib(100000000) 10 times took 1933 ms on average on my machine.

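With the original exception-throwing approach, also with the WriteLine commented out, running Fib(100000) 10 times already took 3746 ms on average; for Fib(1000000), let alone Fib(100000000), I waited a long time without getting a result.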

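RecFunc<> could also store all call arguments as an object[]. But I like Visual Studio's IntelliSense and type inference, and even with object[] the outer wrapper would still need several overloads, so I didn't see what an Invoke(object[]) design would gain and didn't try it.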

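Compared with the original exception-throwing approach, this one avoids a large number of new() allocations (the old approach creates a new exception object on every iteration), does not rely on the CLR's exception handling, and removes one layer of lambda expressions, so in theory it should perform much better :)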

posted @ 2011-02-19 18:39  neutra