Java实现求解二项式系数及代码重构
摘要: 通过代码重构,优化二项式系数求解。包括:使用动态规划法和值对象节省空间效率;接口改造;大整数支持等。
难度: 初级
在上一篇文章中,我总结了从阅读《编程珠玑I》中获得的一些启示。其中有非常重要的一条:代码重审和回顾。通过对以前写过的代码进行重新审视和改进(以现在的经验),使之更具实用性,从而学习新的东西。你敢于面对以前写过的代码吗?如果你都不敢面对,谁还能有这个勇气?
作为代码重审和回顾的一个例子,我对以前的一个粗糙的二项式定理实现进行了重审和改写。当时,主要是为了学习动态规划法技术,运用它来计算二项式系数。
简单回顾二项式定理的相关知识:
(a+b)^n= a^n + C(n,1)a^(n-1)b+...+C(n,k)a^(n-k)b^k +...+ C(n,n-1)ab^(n-1) +b^n
其中: C(n,0)= C(n,n) = 1; C(n,k) = C(n, n-k); C(n,k)= C(n-1, k-1) + C(n-1,k)
计算方法: 例如C(4,2)
S1:构造数组: a[0:4][0:4]= 0
S2:a[0][0] = 1;
S3:a[1][0] = 1; a[1][1] = 1;
S4:a[2][0] = 1; a[2][1] = 2; a[2][2] = 1
S5:a[3][0] = 1; a[3][1] = 3; a[3][2] = 3; a[3][3] = 1
S6:a[4][0] = 1; a[4][1] = 4; a[4][2] = 6
如上所示,自上向下运用公式计算即可得到a[4][2] = 6.动态规划法通过保存并重用已经计算的值,从而避免不必要的计算时间,提高运行时间效率;其代价是一定的空间效率。这里,空间效率是O(n^2);后面可以看到,通过重用空间,空间效率可以降低到O(n)(否则,很快就内存不足了)。
如何改写呢?首先要定义好类的接口和功能。对于该二项式类BinomialTheorem(命名要贴切,英文不知道如何拼写?搜索!):
1. 提供唯一参数为 二项式的幂 ,通过公共构造器传入;
2. 显示该二项式展开式的字符串表示。
这里,仅仅只留出两个公共接口。简单的接口可使类更易使用;此外,在改写的过程中,发现要提供一个计算组合系数C(n,k)的便利方法。
改进的几个方面:
[1] 将其改为值对象。值对象是状态不可变的对象;对于同样的值应当返回同一个实例。值对象需要覆写equals 和hashCode方法;而如何使得对于同一个值参数返回同一个实例呢?参考Integer.ValueOf()实现,定义了一个内部嵌套类BinoTheoremCache,存放256个缓存BinomialTheorem对象。若取BinomialTheorem(i), 0<=i < 256 , 那么,直接从缓存中取;否则,使用 new来获取。这也说明一点:如果不太清楚某某怎么实现的,可以参考JDK源码来获取思路和启示。
[2] 大整数: 计算到 C(60,30) 就发现整型不够用了。n= 60 就无法获得正确结果,这个类是毫无用处的。必须使用大整数;这里直接使用了java提供的 BigInteger. 具体实现是如何呢?留待以后推敲。这里仅仅给出猜想:应该是用整型数组拼接而成。int i可以表示 2147483647; 那么 int[]arr = new int[2] ; arr 可以表示到21474836472147483647这么大的数(转换为字符串形式)。每个整数用一个整型数组表示显然浪费空间;那么可以分段表示:若 i<= 2147483647 , 则用含一个整数的数组表示; 若 2147483647< I <= 21474836472147483647;则用含两个整数的数组表示。依次类推。接下来就需要实现数组的加法、减法、乘法等。
[3] 空间效率: 如果采用初始的a[0:n][0:n]那么,其空间效率是O(n^2);当 n= 10000 时;int[0:10000][0:10000]需要 10000*10000*4/1024/1024= 381.47MB 空间;对于n= 100000 就无能为力了;因此,如果总是囿于小数据处理,那么,会有一种“取之不尽,用之不竭”的错觉;一旦深入到大数据集的处理领域,就不得不经常面对JavaOutOfMemoryError 了。
如何改进其空间效率呢?思路是直观的:重用空间。可以发现:A. 考虑到对称性,C(n,k)= C(n, n-k) , 实际上只需要a[0:n][0:n/2]; B. 当矩阵a[j-1][i]用来计算 a[j][i] i = 0,1,..., n/2 ; 1 <=j <= n之后, a[j-1][i]便不再起作用;因此,可以重用a[j-1][i], i= 0,1,..., n/2 的空间; 于是,只需要a[0:1][0:n/2]就可以了; 进一步地, 可以只需要a[0:n/2]空间 ,不过必须要倒着计算:
假设a[0: 2] 存储着C(4,0) , C(4,1), C(4,2) ; 即:
a[0]= C(4,0), a[1] = C(4,1) , a[2] = C(4,2) ; 则
S1:a[2] = C(5,2) = 2C(4,2) = 2 *a[2];
S2:a[1] = C(5,1) = C(4,0) + C(4,1) = a[0] + a[1];
S3:a[0] = C(5,0) = 1;
如果顺着计算,会产生覆盖:
S1:a[0] = C(5,0) =1;
S2:a[1] = C(5,1) =C(4,0) +C(4,1)
S3: a[2] = C(5,2) = 2C(4,1) //! a[1]已被覆盖为C(5,1)
这里实际上说明了一个比较普遍的现象:虽然动态规划法通常牺牲一定的空间来换取时间效率;但空间效率通常是有一定的提升和优化的空间的。
[4] 一个简单的运行时间测量框架
测量方法运行时间是一个较为频繁的操作,尤其是当实现一定的算法,并期望知道其性能的时候。通过写一个比较简单的运行时间测量框架,可以方便以后的算法性能测量应用。目前还只能接受单个问题规模参数的测试。这与程序测试等其实都是一本万利的事情。最初,可能觉得很麻烦,一段时间熟悉后,当你掌握相关方法和技术后,开发速度就自然提上去了。熟能生巧。
[5]小结:
不得不说,在这个代码重审和改进的过程中,确实学到了不少东西,也体验到了不断精益求精的一些感受,很好很充实。一个类,要写好可不容易!当然,改写后的最终程序并不一定就有多完善,不过,实用性是有很大提升的。程序中有不足之处,还恳请读者指正。
Java 代码实现:
BinomialTheorem.java
package algorithm.dynamicplan; import java.math.BigInteger; /** * BinomialTheorem * 计算二项式系数及展开式 * */ public class BinomialTheorem { /** 二项式的幂 */ private final int power; /** 二项式展开式的系数向量,binomialCoeffs[i] 存储 C(power,i) */ private BigInteger[] binomialCoeffs; // 标记: 如果已经计算过该对象的二项式系数向量,则不必再重新计算 private boolean flag = false; public BinomialTheorem(int power) { this.power = power; if (binomialCoeffs == null) { binomialCoeffs = new BigInteger[power/2+1]; } } public static BinomialTheorem getInstance(int power) { if (power < 0) { throw new IllegalArgumentException("参数错误,指定二项式的幂必须为正整数!"); } if (power < 256) { return BinoTheoremCache.cache[power]; } else { return new BinomialTheorem(power); } } private static class BinoTheoremCache { private BinoTheoremCache() { } private static BinomialTheorem[] cache = new BinomialTheorem[256]; static { for (int i=0; i < 256; i++) { cache[i] = new BinomialTheorem(i); } } } /** * getBinomial : 获得二项式展开式的字符串表达形式 * (a+b)^n = a^n + C(n,1)*a^(n-1)*b + ... + C(n,k)*a^(n-k)*b^k + C(n,n-1)*a*b^(n-1) + b^n */ public String toString() { if (flag == false) { // 计算二项式展开式的系数向量,且仅仅计算一次 calcBinomialCoeffs(); } if (power == 0) { return "(a+b)^0 = 1"; } String beginString = "(a+b)" + "^" + power + " = /n"; StringBuilder result = new StringBuilder(beginString); for (int i = 0; i <= power; i++) { int equivIndex = (i <= power/2) ? i : (power-i); // 衔接二项式系数 C(n,i) : 若 C(n,i) = 1 则省略不显示; 若 i > power/2 , 则 i = power - power/2 . result.append((binomialCoeffs[equivIndex].compareTo(BigInteger.valueOf(1)) == 0) ? "" : binomialCoeffs[equivIndex]); result.append(displayTerm("a", power-i)); // 衔接 * a^(n-i) result.append(displayTerm("b",i)); // 衔接 * b^i result.append(" + "); if (i % 10 == 9) { result.append('/n'); } } result.deleteCharAt(result.length()-2); return result.toString(); } /* * displayTerm : 显示二项式的项 term^power * 若 power = 0 ,则不显示该项 ; 若 power = 1, 则只显示 term * 若 power > 1 , 则显示 term^power* */ private String displayTerm(String term, int power) { if (power == 0) { return "" ; } if (power == 1) { return term; } return "(" + term + "^" + power + ")"; } /** * combinNum : 计算组合数的便利方法 * @return 组合数 C(n,k) */ public static BigInteger combinNum(int n, int k) { if (n < 0 || k < 0 || n < k) throw new IllegalArgumentException(); int finalk = (k<=n/2) ? k : (n-k); BigInteger[] coeffs = new BigInteger[finalk+1]; for (int i=0; i < coeffs.length; i++) { coeffs[i] = BigInteger.valueOf(0); } calcBinomialCoeffs(coeffs, n); return coeffs[finalk]; } /** * binomialCoeff: 计算二项式展开的系数 C(n,0) --- C(n,n) * C(n,k) = C(n,n-k) * C(n,k) = C(n-1,k) + C(n-1, k-1) * C(n,0) = C(n,n) = 1 * */ public void calcBinomialCoeffs() { for (int i=0; i < binomialCoeffs.length; i++) { binomialCoeffs[i] = BigInteger.valueOf(0); } calcBinomialCoeffs(this.binomialCoeffs, power); flag = true; } /* * calcBinomialCoeff: 计算二项式展开的系数 * C(n,k) = C(n,n-k) * C(n,k) = C(n-1,k) + C(n-1, k-1) * C(n,0) = C(n,n) = 1 * */ private static void calcBinomialCoeffs(BigInteger[] binomialCoeffs, int powerNum) { // binomialCoeffs : 存储 C(powerNum, j) = C(powerNum-1 , j-1) + C(powerNum-1, j) for (int i = 0; i <= powerNum; i++) { int upperIndex = Math.min(i/2, binomialCoeffs.length-1); for (int k = upperIndex; k >= 0; k--) { if (k == 0 || i == k) binomialCoeffs[k] = BigInteger.valueOf(1); else if (2*k == i) { binomialCoeffs[k] = binomialCoeffs[k-1].multiply(BigInteger.valueOf(2)); } else { binomialCoeffs[k] = binomialCoeffs[k-1].add(binomialCoeffs[k]); } } } } public boolean equals(Object obj) { if (!(obj instanceof BinomialTheorem)) { return false; } return ((BinomialTheorem)obj).power == power; } public int hashCode() { return power; } }
一个简单的运行时间测试框架 RuntimeMeasurement.java
package common; import java.lang.reflect.Constructor; import java.lang.reflect.Method; public class RuntimeMeasurement { public RuntimeMeasurement(int maxsize) { this.maxsize = maxsize; time = new double[maxsize]; } // 问题最大规模: 以 10 的 size 次幂计 private int maxsize ; // 运行时间以 ms 计 private double[] time ; /** * measureTime : 对指定类型的对象调用指定参数列表的指定方法,并测量其运行时间 * @param type 指定对象类型,必须有一个 参数类型为 int 的公共构造器方法 * @param methodName 指定测试方法名称,要求是空参数列表 */ public void measureTime(Class<?> type, String methodName) { try { Constructor<?> con = type.getConstructor(int.class); Method testMethod = null; for (int i = 0; i < time.length; i++) { Object obj = con.newInstance(power10(i+1)); testMethod = type.getMethod(methodName, new Class<?>[]{}); long start = System.nanoTime(); testMethod.invoke(obj, new Object[] {}); long end = System.nanoTime(); time[i] = ((end - start) / (double)1000000) ; } } catch (Exception e) { e.printStackTrace(); System.out.println(e.getMessage()); } } /** * showTime : 显示已经测量获得的运行时间,在 measureTime 方法调用后调用该方法。 */ public void showTime() { for (int i=0; i < time.length; i++) { System.out.printf("n = %12d : " , power10(i+1)); System.out.printf("%12.3f/n", time[i]); } } private int power10(int n) { int result = 1; while (n > 0) { result *= 10; n--; } return result; } }
BinomialTheorem 测试类 : BinomialTest.java
package algorithm.dynamicplan; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import common.RuntimeMeasurement; public class BinomialTest { public static void autoTest() { for (int i = 0; i < 10; i++) { try { System.out.println(BinomialTheorem.getInstance(i)); } catch (Exception e) { e.printStackTrace(); System.out.println(e.getMessage()); } } } public static void measureTime() { RuntimeMeasurement rm = new RuntimeMeasurement(4); rm.measureTime(BinomialTheorem.class, "calcBinomialCoeffs"); rm.showTime(); } public static void handTest() { try { BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in)); System.out.printf("请输入任意数字: "); String param = ""; try { while ((param = stdin.readLine()) != null && param.length() != 0) { if (! param.matches("//s*0|[1-9][0-9]*//s*")) throw new NumberFormatException(); BinomialTheorem bino = BinomialTheorem.getInstance(Integer.parseInt(param)); System.out.println(bino); } } finally { stdin.close(); } } catch (IOException ioe) { ioe.printStackTrace(); System.out.println("IO 出错: " + ioe.getMessage()); } catch(NumberFormatException nfe) { nfe.printStackTrace(); System.out.println("输入有误,请按格式输入:" + nfe.getMessage()); } catch (Exception e) { e.printStackTrace(); } } public static void testCombinNum() { for (int k = 1; k <= 20; k++) { System.out.printf("C(%d,%d) = %d/n", 20, k, BinomialTheorem.combinNum(20, k)); } for (int n = 10; n <= 100000000; n *= 10) { System.out.printf("C(%d,%d) = %d/n", n, 10, BinomialTheorem.combinNum(n, 10)); } for (int k=10; k <= 100; k += 10) { System.out.printf("C(%d,%d) = %d/n", k, k/2, BinomialTheorem.combinNum(k, k/2)); } } public static void testValueObject() { BinomialTheorem bt3 = BinomialTheorem.getInstance(3); BinomialTheorem another3 = BinomialTheorem.getInstance(3); System.out.println("bt3 hashCode: " + bt3.hashCode()); System.out.println("another3 hashCode: " + another3.hashCode()); System.out.println("another3 == bt3 ? " + (another3 == bt3)); System.out.println("another3.equals(bt3) ? " + another3.equals(bt3)); Integer int100 = Integer.valueOf(100); Integer ano100 = Integer.valueOf(100); System.out.println("int100 hashCode: " + int100.hashCode()); System.out.println("ano100 hashCode: " + ano100.hashCode()); System.out.println("int100 == ano100 ? " + (int100 == ano100)); System.out.println("int100.equals(ano100) ? " + int100.equals(ano100)); String string = "am I happy ? "; String anoString = "am I happy ? "; System.out.println("string hashCode: " + string.hashCode()); System.out.println(" anoString hashCode: " + anoString.hashCode()); System.out.println("string == anoString ? " + (string == anoString)); System.out.println("string.equals(anoString) ? " + string.equals(anoString)); } public static void main(String[] args) { System.out.println("默认 JVM 最大内存: " + Runtime.getRuntime().maxMemory()); System.out.println("********** 值对象性质检验 ************"); testValueObject(); System.out.println("--------- 求二项式的幂的展开式: ------------ "); System.out.println("********** 自动测试小实例 ************"); autoTest(); System.out.println("********** 手动测试实例 ************"); handTest(); System.out.println("******* 计算二项式系数的运行时间测试(ms) *********"); measureTime(); System.out.println("********** 计算组合数 ************"); testCombinNum(); System.out.println(" JVM 已占用总内存: " + Runtime.getRuntime().totalMemory()); System.out.println(" JVM 可用内存: " + Runtime.getRuntime().freeMemory()); } }