计算组合数最大的困难在于数据的溢出,对于大于150的整数n求阶乘很容易超出double类型的范围,那么当C(n,m)中的n=200时,直接用组合公式计算基本就无望了。另外一个难点就是效率。

    对于第一个数据溢出的问题,可以这样解决。因为组合数公式为:

    C(n,m) = n!/(m!(n-m)!)

为了避免直接计算n的阶乘,对公式两边取对数,于是得到:

    ln(C(n,m)) = ln(n!)-ln(m!)-ln((n-m)!)

进一步化简得到:

    大数量级组合数的快速计算方法

这样我们就把连乘转换为了连加,因为ln(n)总是很小的,所以上式很难出现数据溢出。

    为了解决第二个效率的问题,我们对上式再做一步化简。上式已经把连乘法变成了求和的线性运算,也就是说,上式已经极大地简化了计算的复杂度,但是还可以进一步优化。从上式中,我们很容易看出右边的3项必然存在重复的部分。现在我们把右边第一项拆成两部分:

    大数量级组合数的快速计算方法

这样,上式右边第一项就可以被抵消掉,于是得到:

    大数量级组合数的快速计算方法

上式直接减少了2m次对数计算及求和运算。但是这个公式还可以优化。对于上面公式里的求和,当m<n/2时,n-m是一个很大的数,但是当m>n/2时,n-m就会小很多。我们知道:

    C(n,m) = C(n,n-m)

那么通过这个公式,我们可以把小于n/2的m变为大于n/2的n-m再进行计算,结果是一样的,但是却能减少计算量。

    当计算出ln(C(n,m))后,只需要取自然对数,就可以得到组合数:

    C(n,m) = exp(ln(C(n,m)))

这样就完成了组合数的计算。

    用这种方法计算组合数,如果只计算ln(C(n,m))的话,n可以取到整型数据的极限值65535,

    ln(C(65535,32767)) = 45419.6

而计算时间只需要0.01ms。当然,如果要取对数得到最终的组合数的话,n的取值就不能达到这么大了。但是这种算法仍然可以保证n取到1000以上,而不是开头说的150这个极限值。例如:

    C(1000,500) = 2.70288e+299

计算时间仍然小于0.01ms。

    采用我这种算法,不仅n的取值范围大,而且计算速度高,不像用递归算法实现这个问题的时候,很容易陷入递归层次太深而导致计算时间太长。

    算法代码实现如下:

double lnchoose(int n, int m)
{

if (m > n)

{

return 0;

}
if (m < n/2.0)
{
m = n-m;
}

double s1 = 0;
for (int i=m+1; i<=n; i++)
{
s1 += log((double)i);
}

double s2 = 0;
int ub = n-m;
for (int i=2; i<=ub; i++)
{
s2 += log((double)i);
}

return s1-s2;
}

double choose(int n, int m)
{

if (m > n)

{

return 0;

}
return exp(lnchoose(n, m));
}

摘自:http://blog.sina.com.cn/s/blog_4298002e0100eko0.html

有一个用欧几里得扩展算法计算的大数取模计算方法,比较实用

#include<cstdio>
 #include<memory>
 using namespace std;
 const int mod=10007;
 int a[mod];
 void init()
 {
 int i;
     a[0]=1;
 for(i=1;i<mod;i++)
    a[i]=(a[i-1]*i)%mod;
 }
 int gcd(int a,int b){
 if(b==0) return a;
 return gcd(b,a%b);
 }
void e_gcd(int a,int b,int &x,int &y) //扩展欧几里得定理:解ax+by==1。
 {
     if(!b)
     {
         x=1;
         y=0;
     }
     else
     {
         e_gcd(b,a%b,x,y);
         int l=x;
         x=y;
         y=l-a/b*y;
     }
 }
int choose(int n,int m)  
 {
     if(m>n)
    return 0;
 else if(n==m)
    return 1;
 int nn=a[n],mm=(a[m]*a[n-m])%mod;
 int d=gcd(nn,mm);
 nn/=d;
 mm/=d;
     int x,y;
     e_gcd(mm,mod,x,y);
 x=(x+mod)%mod;
     return (x*nn)%mod;
 }
 int main( )
 {
 int t;
 scanf("%d",&t);
 init();
 while(t--)
 {
    int e[100],f[100];
    int i=0,j,m,n;
    memset(e,0,sizeof(e));
    memset(f,0,sizeof(f));
    scanf("%d %d",&n,&m);
    while(n>0)
    {
     e[i++]=n%mod;
     n=n/mod;
    }
    int len=i;
    i=0;
    while(m>0)
    {
            f[i++]=m%mod;
      m=m/mod;
    }
    int re=1;
         for(i=0;i<len;i++)
    {
            re=(re*choose(e[i],f[i]))%mod;
    }
    printf("%d\n",re%mod);
 }
 return 0;
 }

/********************************************/

posted on 2011-09-30 13:15  geeker  阅读(3633)  评论(0编辑  收藏  举报