Java8函数之旅 (六) -- 使用lambda实现Java的尾递归

前言

本篇介绍的不是什么新知识,而是对前面讲解的一些知识的综合运用。众所周知,递归是解决复杂问题的一个很有效的方式,也是函数式语言的核心,在一些函数式语言中,是没有迭代与while这种概念的,因为此类的循环通通可以用递归来实现,这类语言的编译器都对递归的尾递归形式进行了优化,而Java的编译器并没有这样的优化,本篇就要完成这样一个对于尾递归的优化。

什么是尾递归

本篇将使用递归中最简单的阶乘计算来作为例子

递归实现

    /**
     * 阶乘计算 -- 递归解决
     *
     * @param number 当前阶乘需要计算的数值
     * @return number!
     */
    public static int factorialRecursion(final int number) {
        if (number == 1) return number;
        else return number * factorialRecursion(number - 1);
    }

这种方法计算阶乘比较大的数很容易就栈溢出了,原因是每次调用下一轮递归的时候在栈中都需要保存之前的变量,所以整个栈结构类似是这样的

5
  4
    3
      2
        1
-------------------> 
	  栈的深度

在没有递归到底之前,那些中间变量会一直保存着,因此每一次递归都需要开辟一个新的栈空间

尾递归实现

任何递归的尾递归版本都十分简单,分析上面栈溢出的原因就是在每次return的时候都会附带一个变量,因此只需要在return的时候不附带这个变量即可。说起来简单,该怎么做呢?其实也很容易,我们使用一个参数来保存上一轮递归的结果,这样就可以了,因此尾递归的阶乘实现应该是这样的代码。

    /**
     * 阶乘计算 -- 尾递归解决
     *
     * @param factorial 上一轮递归保存的值
     * @param number    当前阶乘需要计算的数值
     * @return number!
     */
    public static int factorialTailRecursion(final int factorial, final int number) {
        if (number == 1) return factorial;
        else return factorialTailRecursion(factorial * number, number - 1);
    }

使用一个factorial变量保存上一轮阶乘计算出的数值,这样return的时候就无需保存变量,整个的计算过程是
(5*4)20 -> (20*3) 60 -> (60*2) 120 -> return 120

这样子通过每轮递归结束后刷新当前的栈空间,复用了栈,就克服了递归的栈溢出问题,像这样的return后面不附带任何变量的递归写法,也就是递归发生在函数最尾部,我们称之为'尾递归'。

使用lambda实现编译器的优化

很显然,如果事情这么简单的话,这篇文章也就结束了,和lambda也没啥关系 😃 然而当你调用上文的尾递归写法之后,发现并没有什么作用,该栈溢出的还是会栈溢出,其实原因我在开头就已经说了,尾递归这样的写法本身并不会有什么用,依赖的是编译器对尾递归写法的优化,在很多语言中编译器都对尾递归有优化,然而这些语言中并不包括java,因此在这里我们使用lambda的懒加载(惰性求值)机制来延迟递归的调用,从而实现栈帧的复用。

设计尾递归的接口

因此我们需要设计一个这样的函数接口来代替递归中的栈帧,通过apply这个函数方法(取名叫apply是因为该方法的参数是一个栈帧,返回值也是一个栈帧,类比function接口的apply)完成每个栈帧之间的连接,除此之外,我们栈帧还需要定义几个方法来丰富这个尾递归的接口。

  • apply(连接栈帧,惰性求值)
  • 判断递归是否结束
  • 得到递归最后的结果
  • 执行递归(及早求值)

根据上面的几条定义,设计出如下的尾递归接口

/**
 * 尾递归函数接口
 * @author : martrix
 */
@FunctionalInterface
public interface TailRecursion<T> {
    /**
     * 用于递归栈帧之间的连接,惰性求值
     * @return 下一个递归栈帧
     */
    TailRecursion<T> apply();

    /**
     * 判断当前递归是否结束
     * @return 默认为false,因为正常的递归过程中都还未结束
     */
    default boolean isFinished(){
        return false;
    }

    /**
     * 获得递归结果,只有在递归结束才能调用,这里默认给出异常,通过工具类的重写来获得值
     * @return 递归最终结果
     */
    default T getResult()  {
        throw new Error("递归还没有结束,调用获得结果异常!");
    }

    /**
     * 及早求值,执行者一系列的递归,因为栈帧只有一个,所以使用findFirst获得最终的栈帧,接着调用getResult方法获得最终递归值
     * @return 及早求值,获得最终递归结果
     */
    default T invoke() {
        return Stream.iterate(this, TailRecursion::apply)
                .filter(TailRecursion::isFinished)
                .findFirst()
                .get()
                .getResult();
    }
}

设计对外统一的尾递归包装类

为了达到可以复用的效果,这里设计一个尾递归的包装类,目的是用于对外统一方法,使得需要尾递归的调用同样的方法即可完成尾递归,不需要考虑内部实现细节,因为所有的递归方法无非只有2类类型的元素组成,一个是怎样调用下次递归,另外一个是递归的终止条件,因此包装方法设计为以下两个

  • 调用下次递归
  • 结束本轮递归
    代码如下
/**
 * 使用尾递归的类,目的是对外统一方法
 *
 * @author : Matrix
 */
public class TailInvoke {
    /**
     * 统一结构的方法,获得当前递归的下一个递归
     *
     * @param nextFrame 下一个递归
     * @param <T>       T
     * @return 下一个递归
     */
    public static <T> TailRecursion<T> call(final TailRecursion<T> nextFrame) {
        return nextFrame;
    }

    /**
     * 结束当前递归,重写对应的默认方法的值,完成状态改为true,设置最终返回结果,设置非法递归调用
     *
     * @param value 最终递归值
     * @param <T>   T
     * @return 一个isFinished状态true的尾递归, 外部通过调用接口的invoke方法及早求值, 启动递归求值。
     */
    public static <T> TailRecursion<T> done(T value) {
        return new TailRecursion<T>() {
            @Override
            public TailRecursion<T> apply() {
                throw new Error("递归已经结束,非法调用apply方法");
            }

            @Override
            public boolean isFinished() {
                return true;
            }

            @Override
            public T getResult() {
                return value;
            }
        };
    }
}

完成阶乘的尾递归函数

通过使用上面的尾递归接口与包装类,只需要调用包装了call与done就可以很轻易的写出尾递归函数,代码如下

    /**
     * 阶乘计算 -- 使用尾递归接口完成
     * @param factorial 当前递归栈的结果值
     * @param number 下一个递归需要计算的值
     * @return 尾递归接口,调用invoke启动及早求值获得结果
     */
    public static TailRecursion<Integer> factorialTailRecursion(final int factorial, final int number) {
        if (number == 1)
            return TailInvoke.done(factorial);
        else
            return TailInvoke.call(() -> factorialTailRecursion(factorial + number, number - 1));
    }

通过观察发现,和原先预想的尾递归方法几乎一模一样,只是使用包装类的call与done方法来表示递归的调用与结束
预想的尾递归

    /**
     * 阶乘计算 -- 尾递归解决
     *
     * @param factorial 上一轮递归保存的值
     * @param number    当前阶乘需要计算的数值
     * @return number!
     */
    public static int factorialTailRecursion(final int factorial, final int number) {
        if (number == 1) return factorial;
        else return factorialTailRecursion(factorial * number, number - 1);
    }

测试尾递归函数

这里作一个说明,因为阶乘的计算如果要计算到栈溢出一般情况下Java的数据类型需要使用BigInteger来包装,为了简化代码,这里的测试仅仅是是测试栈会不会溢出的问题,因此我们将操作符的*改成+这样修改的结果仅仅是结果变小了,但是栈的深度却没有改变。测试代码如下
首先测试 深度为10W的普通递归

测试代码

    @Test
    public void testRec() {
        System.out.println(factorialRecursion(100_000));
    }

理所当然的栈溢出了

java.lang.StackOverflowError
	at test.Factorial.factorialRecursion(Factorial.java:20)
	at test.Factorial.factorialRecursion(Factorial.java:20)
	at test.Factorial.factorialRecursion(Factorial.java:20)
	at test.Factorial.factorialRecursion(Factorial.java:20)
	at test.Factorial.factorialRecursion(Factorial.java:20)
	
Process finished with exit code -1

这里我们测试1000W栈帧的尾递归
尾递归代码

    public static TailRecursion<Long> factorialTailRecursion(final long factorial, final long number) {
        if (number == 1)
            return TailInvoke.done(factorial);
        else
            return TailInvoke.call(() -> factorialTailRecursion(factorial + number, number - 1));
    }

测试代码

    @Test
    public void testTailRec() {
        System.out.println(factorialTailRecursion(1,10_000_000).invoke());
    }

发现结果运转良好

50000005000000

Process finished with exit code 0

由于阶乘的计算一般初始值都为1,所以再进一步包装一下,将初始值设置为1

    public static long factorial(final long number) {
        return factorialTailRecursion(1, number).invoke();
    }

最终调用代码如下,完全屏蔽了尾递归的实现细节

    @Test
    public void testTailRec() {
        System.out.println(factorial(10)); //结果为 3628800
    }

总结

本文讲解了利用lambda懒加载的特性完成了递归中栈帧的复用,实现了函数式语言编译器的'尾递归'优化,虽然上面的例子很简单,但是设计的接口和包装类都是通用的,可以说任何需要使用尾递归的都可以使用上面的代码来实现尾递归的优化,这也算是为编译器帮了点忙吧。

上一篇:开始Java8之旅(五) -- Java8中的排序
下一篇:开始Java8之旅(七) -- 函数式备忘录模式优化递归

posted @ 2017-10-24 14:38  祈求者-  阅读(8530)  评论(3编辑  收藏  举报