HDU 5793 A Boring Question (逆元+快速幂+费马小定理) ---2016杭电多校联合第六场
A Boring Question
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)
Total Submission(s): 156 Accepted Submission(s): 72
Problem Description
Input
The first line of the input contains the only integer T,(1≤T≤10000)
Then T lines follow,the i-th line contains two integers n,m,(0≤n≤109,2≤m≤109)
Then T lines follow,the i-th line contains two integers n,m,(0≤n≤109,2≤m≤109)
Output
For each n and m,output the answer in a single line.
Sample Input
2
1 2
2 3
Sample Output
3
13
Author
UESTC
Source
Recommend
wange2014
题意:有m个数,求c(k2,k1)*c(k3,k2).....*c(km,km-1)的值,每个数的值在[0,n]之间,求所有情况的值的总和。
题解:首先看题目中的条件 And (kj+1kj)=0 while kj+1<kj 所以我们只需要考虑非递减序列即可. 也就是说把该问题转化为n的阶乘除以组成n的m个数的各自阶乘的积,首先进行打表:
#include<iostream> #include<cstring> #include<cstdio> using namespace std; int n,m; long long C[1005][1005]; long long mod = 1000000007; long long cal(int cur,int pre) { if(cur==m+1) return 1; long long ans = 0; for(int i=pre;i<=n;i++) { //printf("%lld\n",C[i][pre]); ans+=C[i][pre]*cal(cur+1,i)%mod; ans%=mod; } return ans; } int main() { C[0][0] = 1; C[1][0]=C[1][1]=1; for(int i=2;i<=1000;i++) { C[i][0] = 1; C[i][i] = 1; for(int j=1;j<i;j++) { C[i][j] = (C[i-1][j] + C[i-1][j-1])%mod; } } int data[10][10]; memset(data,0,sizeof(data)); for(int j=2;j<=5;j++) { for(int i=0;i<=5;i++) { n=i,m=j; printf("n: %d m: %d ",n,m); printf("%lld\n",cal(1,0)); } } while(scanf("%d%d",&n,&m)!=EOF) { printf("%lld\n",cal(1,0)); } }
运行结果如下:
仔细观察结果我们可以发现,这是等比数列前n项和,即m^ 0+m^1+m^ 2+m^3+.....+m ^n=(m^(n+1)-1)/(m-1);答案是对mod=1e9+7取模的,我们知道mod是一个素数,且m的范围是int,所以gcd(m,mod)=1; 所以满足费马小定理的条件,根据费马小定理我们得知分母m-1对mod的逆元为(m-1)^(mod-2); ans=(m^(n+1)-1)%mod*(m-1)^(mod-2)%mod;利用快速幂即可求出结果。
AC代码:
#include <iostream> #include <cstdio> #include <cmath> #include <cstring> using namespace std; typedef long long ll; const int mod=1000000007; ll pow1(ll a,ll b) { ll ans=1; while(b) { if(b&1) { ans=ans*a%mod; } b>>=1; a=a*a%mod; } return ans; } int main() { int t; ll n,m; cin>>t; while(t--) { cin>>n>>m; ll ans1,ans2,ans3; ans1=(pow1(m,n+1)-1+mod)%mod; ans2=(pow1(m-1,mod-2)+mod)%mod; ans3=ans1*ans2%mod; //cout<<pow1(2,4)<<endl; //cout<<ans1<<endl<<ans2<<endl; cout<<ans3<<endl; } return 0; }
官方给出的公式推导表示没看懂。