AT2070 Card Game for Three(组合数学)
解题思路
前面的思路还是很好想的,就是要枚举最后一个\(a\)在哪出现算贡献,之后我先想的容斥,结果彻底偏了。。后来调了很久发现自己傻逼了,似乎不能容斥,终于走上正轨23333。首先可以写出一个\(O(n^2)\)的玩意,就是
\[ans=\sum\limits_{i=n}^{sum}C(n-1,i-1)*3^{sum-i}*\sum\limits_{j=max(0,i-n-m)}^{min(k,i-n)}C(i-n,j)
\]
这个式子就是枚举\(a\)的最后一个位置\(i\),然后从前面的\(n-1\)里选\(i-1\)个\(a\),再从剩下的里面选出\(j\)个\(c\),而\(j\)还要满足不能使\(b\)超过\(m\),不能使\(c\)超过\(k\),最后后面剩下的位置随意放。这样是\(O(n^2)\)的,优化也很简单,就是发现后面的一坨对应杨辉三角同行连续一段,然后每次\(i+1\)时转移到下一行连续一段,这个随便讨论一下即可。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N=1000005;
const int MOD=1e9+7;
int n,m,k,ans,fac[N],inv[N],sum;
inline int fast_pow(int x,int y){
int ret=1;
for(;y;y>>=1){
if(y&1) ret=1ll*ret*x%MOD;
x=1ll*x*x%MOD;
}
return ret;
}
inline int C(int x,int y){
return 1ll*fac[x]*inv[y]%MOD*inv[x-y]%MOD;
}
signed main(){
scanf("%d%d%d",&n,&m,&k);
fac[0]=inv[0]=1; sum=n+m+k;
for(int i=1;i<=sum;i++) fac[i]=1ll*fac[i-1]*i%MOD;
inv[sum]=fast_pow(fac[sum],MOD-2);
for(int i=sum-1;i;i--) inv[i]=1ll*inv[i+1]*(i+1)%MOD;
// for(int i=n;i<=sum;i++){
// int now=0;
// for(int j=max(i-n-m,1ll*0);j<=min(k,i-n);j++)
// (now+=C(i-n,j))%=MOD;
// ans+=now%MOD*C(i-1,n-1)%MOD*fast_pow(3,sum-i)%MOD;
// ans%=MOD;
// }
// cout<<ans<<endl; ans=0;
if(m<k) swap(m,k); int now=1;
for(int i=n;i<=sum;i++){
if(k+n>=i) {
(ans+=1ll*C(i-1,n-1)*fast_pow(3,sum-i)%MOD*now%MOD)%=MOD;
if(k+n!=i) now=now*2%MOD;
}
else if(n+m>=i) {
now=now*2-C(i-n-1,k);
now=(now%MOD+MOD)%MOD;
(ans+=1ll*now*C(i-1,n-1)%MOD*fast_pow(3,sum-i)%MOD)%=MOD;
}
else {
now=now*2-C(i-n-1,k); now=(now%MOD+MOD)%MOD;
now=(now-C(i-n-1,i-n-m-1)+MOD)%MOD;
(ans+=1ll*now*C(i-1,n-1)%MOD*fast_pow(3,sum-i)%MOD)%=MOD;
}
}
// for(int i=n;i<=sum;i++) {
// (ans+=1ll*C(i-1,n-1)*fast_pow(2,i-n)%MOD*fast_pow(3,sum-i)%MOD)%=MOD;
// if(i==8) cout<<ans<<endl;
// if(n+k+1<=i)
// (ans-=1ll*C(i-1,n-1)*C(i-n,k+1)%MOD*fast_pow(2,i-n-k-1)%MOD*fast_pow(3,sum-i)%MOD)%=MOD;
// if(i==8) cout<<ans<<endl;
// if(n+m+1<=i)
// (ans-=1ll*C(i-1,n-1)*C(i-n,m+1)%MOD*fast_pow(2,i-n-m-1)%MOD*fast_pow(3,sum-i)%MOD)%=MOD;
// cout<<ans<<endl;
// }
// for(int i=n+m;i<sum;i++){
// (ans-=1ll*now*fast_pow(3,sum-n-m)%MOD)%=MOD;
// now=1ll*now*i%MOD; now=(now+C(n+m,i-n-m+1))%MOD;
// }
// now=1;
// for(int i=n+k;i<sum;i++){
// (ans-=1ll*now*fast_pow(3,sum-n-k)%MOD)%=MOD;
// now=1ll*now*i%MOD; now=(now+C(n+k,i-n-k+1))%MOD;
// }
printf("%d\n",(ans+MOD)%MOD);
return 0;
}