Java实现求解二项式系数及代码重构

      摘要: 通过代码重构,优化二项式系数求解。包括:使用动态规划法和值对象节省空间效率;接口改造;大整数支持等。

      难度: 初级 

 

      在上一篇文章中,我总结了从阅读《编程珠玑I》中获得的一些启示。其中有非常重要的一条:代码重审和回顾。通过对以前写过的代码进行重新审视和改进(以现在的经验),使之更具实用性,从而学习新的东西。你敢于面对以前写过的代码吗?如果你都不敢面对,谁还能有这个勇气?

        作为代码重审和回顾的一个例子,我对以前的一个粗糙的二项式定理实现进行了重审和改写。当时,主要是为了学习动态规划法技术,运用它来计算二项式系数。

 

            简单回顾二项式定理的相关知识:

           (a+b)^n= a^n + C(n,1)a^(n-1)b+...+C(n,k)a^(n-k)b^k +...+ C(n,n-1)ab^(n-1) +b^n

           其中: C(n,0)= C(n,n) = 1; C(n,k) = C(n, n-k); C(n,k)= C(n-1, k-1) + C(n-1,k)

 

           计算方法: 例如C(4,2)

            S1:构造数组: a[0:4][0:4]= 0

            S2:a[0][0] = 1;

            S3:a[1][0] = 1; a[1][1] = 1;

            S4:a[2][0] = 1; a[2][1] = 2; a[2][2] = 1

            S5:a[3][0] = 1; a[3][1] = 3; a[3][2] = 3; a[3][3] = 1

            S6:a[4][0] = 1; a[4][1] = 4; a[4][2] = 6

        如上所示,自上向下运用公式计算即可得到a[4][2] = 6.动态规划法通过保存并重用已经计算的值,从而避免不必要的计算时间,提高运行时间效率;其代价是一定的空间效率。这里,空间效率是O(n^2);后面可以看到,通过重用空间,空间效率可以降低到O(n)(否则,很快就内存不足了)。

 

        如何改写呢?首先要定义好类的接口和功能。对于该二项式类BinomialTheorem(命名要贴切,英文不知道如何拼写?搜索!):

          1.  提供唯一参数为 二项式的幂 ,通过公共构造器传入;

          2.  显示该二项式展开式的字符串表示。

               这里,仅仅只留出两个公共接口。简单的接口可使类更易使用;此外,在改写的过程中,发现要提供一个计算组合系数C(n,k)的便利方法。

 

         改进的几个方面:

         [1]  将其改为值对象。值对象是状态不可变的对象;对于同样的值应当返回同一个实例。值对象需要覆写equals 和hashCode方法;而如何使得对于同一个值参数返回同一个实例呢?参考Integer.ValueOf()实现,定义了一个内部嵌套类BinoTheoremCache,存放256个缓存BinomialTheorem对象。若取BinomialTheorem(i), 0<=i < 256 , 那么,直接从缓存中取;否则,使用 new来获取。这也说明一点:如果不太清楚某某怎么实现的,可以参考JDK源码来获取思路和启示。

        [2]  大整数: 计算到 C(60,30) 就发现整型不够用了。n= 60 就无法获得正确结果,这个类是毫无用处的。必须使用大整数;这里直接使用了java提供的 BigInteger. 具体实现是如何呢?留待以后推敲。这里仅仅给出猜想:应该是用整型数组拼接而成。int i可以表示 2147483647; 那么 int[]arr = new int[2] ; arr 可以表示到21474836472147483647这么大的数(转换为字符串形式)。每个整数用一个整型数组表示显然浪费空间;那么可以分段表示:若 i<= 2147483647 , 则用含一个整数的数组表示; 若 2147483647< I <= 21474836472147483647;则用含两个整数的数组表示。依次类推。接下来就需要实现数组的加法、减法、乘法等。

       [3]  空间效率: 如果采用初始的a[0:n][0:n]那么,其空间效率是O(n^2);当 n= 10000 时;int[0:10000][0:10000]需要 10000*10000*4/1024/1024= 381.47MB 空间;对于n= 100000 就无能为力了;因此,如果总是囿于小数据处理,那么,会有一种“取之不尽,用之不竭”的错觉;一旦深入到大数据集的处理领域,就不得不经常面对JavaOutOfMemoryError 了。

 

       如何改进其空间效率呢?思路是直观的:重用空间。可以发现:A. 考虑到对称性,C(n,k)= C(n, n-k) , 实际上只需要a[0:n][0:n/2]; B. 当矩阵a[j-1][i]用来计算 a[j][i] i = 0,1,..., n/2 ; 1 <=j <= n之后, a[j-1][i]便不再起作用;因此,可以重用a[j-1][i], i= 0,1,..., n/2 的空间; 于是,只需要a[0:1][0:n/2]就可以了; 进一步地, 可以只需要a[0:n/2]空间 ,不过必须要倒着计算:

 

            假设a[0: 2] 存储着C(4,0) , C(4,1), C(4,2) ; 即:

            a[0]= C(4,0), a[1] = C(4,1) , a[2] = C(4,2) ; 则

            S1:a[2] = C(5,2) = 2C(4,2) = 2 *a[2];

            S2:a[1] = C(5,1) = C(4,0) + C(4,1) = a[0] + a[1];

            S3:a[0] = C(5,0) = 1;

 

             如果顺着计算,会产生覆盖:

            S1:a[0] = C(5,0) =1;

            S2:a[1] = C(5,1) =C(4,0) +C(4,1)

            S3: a[2] = C(5,2) = 2C(4,1) //! a[1]已被覆盖为C(5,1)

            这里实际上说明了一个比较普遍的现象:虽然动态规划法通常牺牲一定的空间来换取时间效率;但空间效率通常是有一定的提升和优化的空间的。

 

         [4] 一个简单的运行时间测量框架

        测量方法运行时间是一个较为频繁的操作,尤其是当实现一定的算法,并期望知道其性能的时候。通过写一个比较简单的运行时间测量框架,可以方便以后的算法性能测量应用。目前还只能接受单个问题规模参数的测试。这与程序测试等其实都是一本万利的事情。最初,可能觉得很麻烦,一段时间熟悉后,当你掌握相关方法和技术后,开发速度就自然提上去了。熟能生巧。

 

        [5]小结:

        不得不说,在这个代码重审和改进的过程中,确实学到了不少东西,也体验到了不断精益求精的一些感受,很好很充实。一个类,要写好可不容易!当然,改写后的最终程序并不一定就有多完善,不过,实用性是有很大提升的。程序中有不足之处,还恳请读者指正。

 

     Java 代码实现:

     BinomialTheorem.java         

package algorithm.dynamicplan;
import java.math.BigInteger;
/**
 * BinomialTheorem
 * 计算二项式系数及展开式
 * 
 */
public class BinomialTheorem  {  
    
    /** 二项式的幂 */
    private final int power;
    
    /** 二项式展开式的系数向量,binomialCoeffs[i] 存储 C(power,i) */
    private BigInteger[] binomialCoeffs;
    
    // 标记: 如果已经计算过该对象的二项式系数向量,则不必再重新计算
    private boolean flag = false;
    
    public BinomialTheorem(int power) 
    {    
        this.power = power;
        if (binomialCoeffs == null) {
            binomialCoeffs = new BigInteger[power/2+1];
        }
    }
    
    public static BinomialTheorem getInstance(int power)
    {
        if (power < 0) {
            throw new IllegalArgumentException("参数错误,指定二项式的幂必须为正整数!");
        }
        if (power < 256) {
            return BinoTheoremCache.cache[power];
        }
        else {
            return new BinomialTheorem(power);
        }
    }
    
    private static class BinoTheoremCache {
        
        private BinoTheoremCache() { }
        
        private static BinomialTheorem[] cache = new BinomialTheorem[256];
        static {
            for (int i=0; i < 256; i++) {
                cache[i] = new BinomialTheorem(i);
            }
        }
    }
    
    
    /** 
     * getBinomial : 获得二项式展开式的字符串表达形式
     * (a+b)^n = a^n + C(n,1)*a^(n-1)*b + ... + C(n,k)*a^(n-k)*b^k + C(n,n-1)*a*b^(n-1) + b^n
     */
    
    public String toString() {
        if (flag == false) {  // 计算二项式展开式的系数向量,且仅仅计算一次
            calcBinomialCoeffs();   
        }
        
        if (power == 0) {
            return "(a+b)^0 = 1";
        }
        
        String beginString = "(a+b)" + "^" + power + " = /n";
        StringBuilder result = new StringBuilder(beginString);
        
        for (int i = 0; i <= power; i++) {
                
            int equivIndex = (i <= power/2) ? i : (power-i); 
            
            // 衔接二项式系数 C(n,i) : 若 C(n,i) = 1 则省略不显示; 若 i > power/2 , 则 i = power - power/2 .
            result.append((binomialCoeffs[equivIndex].compareTo(BigInteger.valueOf(1)) == 0) ? "" : binomialCoeffs[equivIndex]);
            result.append(displayTerm("a", power-i));  // 衔接 * a^(n-i)
            result.append(displayTerm("b",i)); // 衔接 * b^i
            result.append(" + ");
            if (i % 10 == 9) {
                result.append('/n');
            }
        }
        result.deleteCharAt(result.length()-2);
        
        return result.toString();
        
    }
    
    /*
     * displayTerm : 显示二项式的项 term^power
     * 若 power = 0 ,则不显示该项 ; 若 power = 1, 则只显示 term
     * 若 power > 1 , 则显示 term^power*
     */
    private String displayTerm(String term, int power)
    {
        if (power == 0) { return "" ; }
        if (power == 1) { return term; }
        return "(" + term + "^" + power + ")";
    }
    
    /**
     * combinNum : 计算组合数的便利方法 
     * @return 组合数 C(n,k)
     */
    public static BigInteger combinNum(int n, int k)
    {
        if (n < 0 || k < 0 || n < k)
            throw new IllegalArgumentException();
        int finalk = (k<=n/2) ? k : (n-k);
        BigInteger[] coeffs = new BigInteger[finalk+1];
        for (int i=0; i < coeffs.length; i++) {
            coeffs[i] = BigInteger.valueOf(0);
        }
        calcBinomialCoeffs(coeffs, n);
        return coeffs[finalk];
    }
    
    /**
     * binomialCoeff: 计算二项式展开的系数 C(n,0) --- C(n,n)
     * C(n,k) = C(n,n-k) 
     * C(n,k) = C(n-1,k) + C(n-1, k-1)
     * C(n,0) = C(n,n) = 1
     * 
     */
    public void calcBinomialCoeffs() {
        
        for (int i=0; i < binomialCoeffs.length; i++) {
            binomialCoeffs[i] = BigInteger.valueOf(0);
        }
        calcBinomialCoeffs(this.binomialCoeffs, power);
        flag = true;
    }
    
    /*
     * calcBinomialCoeff: 计算二项式展开的系数
     * C(n,k) = C(n,n-k) 
     * C(n,k) = C(n-1,k) + C(n-1, k-1)
     * C(n,0) = C(n,n) = 1
     * 
     */
    private static void calcBinomialCoeffs(BigInteger[] binomialCoeffs, int powerNum)
    {
        // binomialCoeffs : 存储 C(powerNum, j) = C(powerNum-1 , j-1) + C(powerNum-1, j)
        for (int i = 0; i <= powerNum; i++) {
            int upperIndex = Math.min(i/2, binomialCoeffs.length-1);
            for (int k = upperIndex; k >= 0; k--) { 
                if (k == 0 || i == k)
                    binomialCoeffs[k] = BigInteger.valueOf(1);
                else if (2*k == i) {
                    binomialCoeffs[k] = binomialCoeffs[k-1].multiply(BigInteger.valueOf(2));
                } else {
                    binomialCoeffs[k] = binomialCoeffs[k-1].add(binomialCoeffs[k]);    
                }
            }
        }
    }
    
    public boolean equals(Object obj)
    {
        if (!(obj instanceof BinomialTheorem)) { return false; }
        return ((BinomialTheorem)obj).power == power;
    }
    
    public int hashCode()
    {
        return power;
    }
} 

 

     一个简单的运行时间测试框架   RuntimeMeasurement.java 

   

package common;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
public class RuntimeMeasurement {
    
    public RuntimeMeasurement(int maxsize) {
        this.maxsize = maxsize;
        time = new double[maxsize];
    }
    
    // 问题最大规模: 以 10 的 size 次幂计
    private int maxsize ;
    
    // 运行时间以 ms 计
    private double[] time ;
    
    /**
     * measureTime : 对指定类型的对象调用指定参数列表的指定方法,并测量其运行时间
     * @param type  指定对象类型,必须有一个 参数类型为 int 的公共构造器方法
     * @param methodName  指定测试方法名称,要求是空参数列表
     */
    public void measureTime(Class<?> type, String methodName)
    {
        try {
             Constructor<?> con = type.getConstructor(int.class);
             Method testMethod = null;
             for (int i = 0; i < time.length; i++) {    
                  Object obj = con.newInstance(power10(i+1));
                  testMethod = type.getMethod(methodName, new Class<?>[]{});
                  long start = System.nanoTime();  
                  testMethod.invoke(obj, new Object[] {});
                  long end = System.nanoTime();
                  time[i] =  ((end - start) / (double)1000000) ;       
                }
              
        } catch (Exception e) {
                e.printStackTrace();
                System.out.println(e.getMessage());
        }
            
    }
    
    /**
     * showTime : 显示已经测量获得的运行时间,在 measureTime 方法调用后调用该方法。
     */
    public void showTime()
    {
        for (int i=0; i < time.length; i++) {
            System.out.printf("n = %12d : " , power10(i+1));    
            System.out.printf("%12.3f/n", time[i]);    
        }
    }
    private int power10(int n)
    {
        int result = 1;
        while (n > 0) {
            result *= 10;
            n--;
        }
        return result;
    }
    
} 

 

    BinomialTheorem 测试类 : BinomialTest.java

package algorithm.dynamicplan;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import common.RuntimeMeasurement;
public class BinomialTest {
    
    public static void autoTest()
    {    
        for (int i = 0; i < 10; i++) {
            try {
              System.out.println(BinomialTheorem.getInstance(i));
            } catch (Exception e) {
                e.printStackTrace();
                System.out.println(e.getMessage());
            }
        }
    }
    
    public static void measureTime()
    {
        RuntimeMeasurement rm = new RuntimeMeasurement(4);
        rm.measureTime(BinomialTheorem.class, "calcBinomialCoeffs");
        rm.showTime();
    }
    
    public static void handTest()
    {
        try {
            BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
            System.out.printf("请输入任意数字: ");
            String param = "";
            try {
                while ((param = stdin.readLine()) != null && param.length() != 0) {
                    if (! param.matches("//s*0|[1-9][0-9]*//s*")) 
                        throw new NumberFormatException();
                       
                    BinomialTheorem bino = BinomialTheorem.getInstance(Integer.parseInt(param));
                    System.out.println(bino);               
                }   
            } finally {
                stdin.close();
            }
                   
        }
        catch (IOException ioe) {
            ioe.printStackTrace();
            System.out.println("IO 出错: " + ioe.getMessage());
        }
        catch(NumberFormatException nfe) {
            nfe.printStackTrace();
            System.out.println("输入有误,请按格式输入:" + nfe.getMessage());
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
    
    public static void testCombinNum()
    {
        for (int k = 1; k <= 20; k++) {
            System.out.printf("C(%d,%d) = %d/n", 20, k, BinomialTheorem.combinNum(20, k));
        }
        
        for (int n = 10; n <= 100000000; n *= 10) {
            System.out.printf("C(%d,%d) = %d/n", n, 10, BinomialTheorem.combinNum(n, 10));
        }
        
        for (int k=10; k <= 100; k += 10) {
            System.out.printf("C(%d,%d) = %d/n", k, k/2, BinomialTheorem.combinNum(k, k/2));
        }
    }
    
    public static void testValueObject()
    {
         BinomialTheorem bt3 = BinomialTheorem.getInstance(3);         
         BinomialTheorem another3 = BinomialTheorem.getInstance(3);
         System.out.println("bt3 hashCode: " + bt3.hashCode());
         System.out.println("another3 hashCode: " + another3.hashCode());
         System.out.println("another3 == bt3 ? " + (another3 == bt3));
         System.out.println("another3.equals(bt3) ? " + another3.equals(bt3));
         
         Integer int100 = Integer.valueOf(100);
         Integer ano100 = Integer.valueOf(100);
         System.out.println("int100 hashCode: " + int100.hashCode());
         System.out.println("ano100 hashCode: " + ano100.hashCode());
         System.out.println("int100 == ano100 ? " + (int100 == ano100));
         System.out.println("int100.equals(ano100) ? " + int100.equals(ano100));
         
         String string = "am I happy ? ";
         String anoString = "am I happy ? ";
         System.out.println("string hashCode: " + string.hashCode());
         System.out.println(" anoString hashCode: " + anoString.hashCode());
         System.out.println("string == anoString ? " + (string == anoString));
         System.out.println("string.equals(anoString) ? " + string.equals(anoString));
    }
    public static void main(String[] args) 
    {
        
        System.out.println("默认 JVM 最大内存: " + Runtime.getRuntime().maxMemory());
        
         System.out.println("********** 值对象性质检验 ************");
         testValueObject();
        
         System.out.println("--------- 求二项式的幂的展开式: ------------ ");
        
         System.out.println("********** 自动测试小实例 ************");
         autoTest();
         
         System.out.println("********** 手动测试实例 ************");
         handTest();     
         
         System.out.println("******* 计算二项式系数的运行时间测试(ms) *********");
         measureTime();
         
         System.out.println("********** 计算组合数 ************");
         testCombinNum();    
           
         System.out.println(" JVM 已占用总内存: " + Runtime.getRuntime().totalMemory());
         System.out.println(" JVM 可用内存: " + Runtime.getRuntime().freeMemory());
         
    }
}

 

posted @ 2011-04-22 17:33  琴水玉  阅读(1844)  评论(0编辑  收藏  举报