Luogu4491 [HAOI2018]染色 【容斥原理】【NTT】
题目分析:
一开始以为是直接用指数型生成函数,后来发现复杂度不对,想了一下容斥的方法。
对于有$i$种颜色恰好出现$s$次的情况,利用容斥原理得到方案数为
$$\binom{m}{i}\frac{P_{is}^{n}}{(s!)^i}(\sum_{j=0}^{m-i}(-1)^j\binom{m-i}{j}\frac{P_{js}^{n-is}}{(s!)^j}(m-i-j)^{n-is-js})$$
值得注意的是$n-is-js<0$的时候,后面的式子直接等于$0$,特判一下就行了。
那么答案就等于
$$\sum_{i=0}^{m}w_i\binom{m}{i}\frac{P_{is}^{n}}{(s!)^i}(\sum_{j=0}^{m-i}(-1)^j\binom{m-i}{j}\frac{P_{js}^{n-is}}{(s!)^j}(m-i-j)^{n-is-js})$$
式子看着很长,但其实没啥味道,把组合数和排列数展开,常数项提出来,约分,可以得到上面的式子等价于
$$(n!)*(m!)*\sum_{i=0}^{m}\frac{w_i}{(i!)(s!)^i}(\sum_{j=0}^{m-i}\frac{(-1)^j(m-i-j)^{n-is-js}}{(m-i-j)!j!(n-is-js)!(s!)^j})$$
对于后面的那个求和,使用肉眼观察法,会发现是个关于$j$和$m-i-j$的卷积。因为$m-i-j$的值确定了就意味着$n-is-js$的值也确定了。所以NTT搞出来
时间复杂度$O(nlogn)$
代码:
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 const int maxn = 202000; 5 const int mod = 1004535809; 6 const int gg = 3; 7 8 int n,m,s; 9 int w[maxn]; 10 11 int fac[10200000],A[maxn*4],B[maxn*4]; 12 13 int fast_pow(int now,int pw){ 14 int ans = 1,dt = now,bit = 1; 15 while(bit <= pw){ 16 if(bit &pw){ans = 1ll*ans*dt%mod;} 17 dt = 1ll*dt*dt%mod; bit<<=1; 18 } 19 return ans; 20 } 21 22 void buildfunc(){ 23 fac[0] = 1; 24 for(int i=1;i<=max(n,m);i++) fac[i] = 1ll*fac[i-1]*i%mod; 25 for(int i=0;i<=m;i++){ 26 A[i] = 1ll*fast_pow(fac[s],i)*fac[i]%mod; 27 A[i] = fast_pow(A[i],mod-2); 28 if(i&1) A[i] = 1ll*(mod-1)*A[i]%mod; 29 } 30 for(int i=0;i<=m;i++){ 31 int z = m-i; 32 if(n-z*s < 0) {B[i] = 0;continue;} 33 int rem = n-z*s; 34 B[i] = 1ll*fac[i]*fac[rem]%mod; 35 B[i] = fast_pow(B[i],mod-2); 36 B[i] = 1ll*B[i]*fast_pow(i,rem)%mod; 37 } 38 } 39 40 int ord[maxn*4]; 41 42 void NTT(int *d,int len,int dr){ 43 for(int i=0;i<len;i++) if(ord[i] < i) swap(d[i],d[ord[i]]); 44 for(int i=1;i<len;i<<=1){ 45 int w = fast_pow(gg,(mod-1)/(2*i)); 46 if(dr == -1) w = fast_pow(w,mod-2); 47 for(int j=0;j<len;j+=(i<<1)){ 48 for(int k=0,wn=1;k<i;k++,wn = 1ll*wn*w%mod){ 49 int x = d[j+k],y = 1ll*wn*d[j+k+i]%mod; 50 d[j+k] = (x+y)%mod; 51 d[j+k+i] = (x-y+mod)%mod; 52 } 53 } 54 } 55 if(dr == -1){ 56 int iv = fast_pow(len,mod-2); 57 for(int i=0;i<len;i++){d[i] = 1ll*d[i]*iv%mod;} 58 } 59 } 60 61 void work(){ 62 buildfunc(); 63 /*int reans = 0; 64 for(int i=0;i<=m;i++){ 65 int z = 1ll*w[i]*fast_pow(1ll*fac[i]*fast_pow(fac[s],i)%mod,mod-2)%mod; 66 int np = 0,kp = 0; 67 for(int j=0;j<m-i;j++){ 68 if(n-i*s-j*s < 0) continue; 69 int mp = 0; 70 mp = 1ll*fac[m-i-j]*fac[j]%mod*fac[n-i*s-j*s]%mod*fast_pow(fac[s],j)%mod; 71 mp = 1ll*fast_pow(mp,mod-2)*fast_pow(m-i-j,n-i*s-j*s)%mod; 72 if(j & 1) mp = 1ll*(mod-1)*mp%mod; 73 kp += 1ll*A[j]*B[m-i-j]%mod; 74 kp %= mod; 75 np += mp; 76 np %= mod; 77 } 78 reans += 1ll*z*np%mod; 79 reans %= mod; 80 } 81 reans = 1ll*reans*fac[n]%mod*fac[m]%mod; 82 printf("%d\n",reans);return;*/ 83 84 85 int hk = 1,pi = 0; while(hk <= m+m) hk*=2,pi++; 86 for(int i=0;i<hk;i++) ord[i] = (ord[i>>1]>>1) + ((i&1)<<(pi-1)); 87 NTT(A,hk,1); NTT(B,hk,1); 88 for(int i=0;i<hk;i++) A[i] = 1ll*A[i]*B[i]%mod; 89 NTT(A,hk,-1); 90 int ans = 0; 91 for(int i=0;i<=m;i++){ 92 int z = 1ll*fac[m-i]*fast_pow(fac[s],m-i)%mod; 93 z = 1ll*fast_pow(z,mod-2)*w[m-i]%mod; 94 ans += 1ll*z*A[i]%mod; 95 ans %= mod; 96 } 97 ans = 1ll*ans*fac[n]%mod*fac[m]%mod; 98 printf("%d\n",ans); 99 } 100 101 int main(){ 102 scanf("%d%d%d",&n,&m,&s); 103 for(int i=0;i<=m;i++) scanf("%d",&w[i]); 104 work(); 105 return 0; 106 }