[HAOI2018]染色(NTT)
前置芝士
前置定义
\[\begin{aligned}\\
f_i=C_m^i\cdot \frac{n!}{(S!)^i(n-iS)!}\cdot (m-i)^{n-iS}\\
ans_i=\sum\limits_{j=i}^{lim} (-1)^{j-i}C_j^i f_j,\quad lim=\min(m,\lfloor n/S\rfloor)\\
\end{aligned}\]
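理解:从\(m\)种颜色中选出\(i\)种,令它们各恰好出现\(S\)次;这\(iS\)个格子与剩余\(n-iS\)个格子的排列方式用可重集全排列\(\frac{n!}{(S!)^i(n-iS)!}\)计算,剩余格子再任意染上其余\(m-i\)种颜色。不过这样剩余格子中也可能有颜色恰好出现\(S\)次,所以\(f_i\)存在重复计数,需要容斥一下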
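\(C_j^i\)的含义:一个恰好有\(j\)种颜色出现\(S\)次的方案,在\(f_i\)中被统计了\(C_j^i\)次(从这\(j\)种里任选\(i\)种作为“指定”的颜色),即\(f_i=\sum_{j\ge i}C_j^i\,ans_j\);对其做二项式反演,就得到上面带系数\((-1)^{j-i}C_j^i\)的容斥式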
推式
\[\begin{aligned}\\
ans_i=\sum\limits_{j=i}^{lim} (-1)^{j-i}\frac{j!}{i!(j-i)!}f_j\\
ans_i\cdot i!=\sum\limits_{j=i}^{lim}(\frac{(-1)^{j-i}}{(j-i)!})\cdot (f_j\cdot j!)\\
\end{aligned}\]
设生成函数\(G,F\)分别对应\(G_k=\frac{(-1)^k}{k!}\)与\(F_j=f_j\cdot j!\),再把\(F\)翻转一下(令\(F_k\leftarrow f_{lim-k}\cdot(lim-k)!\)),这样两项下标之和\((j-i)+(lim-j)=lim-i\)为定值,就变成了卷积:
\[\begin{aligned}\\
ans_i\cdot i!&=\sum\limits_{j=i}^{lim}G_{j-i}\cdot F_{lim-j}\\
H&=G*F\\
ans_i\cdot i!&=H_{lim-i}\\
\end{aligned}\]
Code
上\(NTT\)模板就行
#include<bits/stdc++.h>
// 64-bit integer type used for all modular arithmetic.
using LL = long long;
// NTT-friendly prime modulus and its primitive root.
constexpr LL mod = 1004535809;
constexpr LL gg = 3;
// Static table capacity; fac/fav are indexed up to max(n, m).
constexpr LL maxn = 10000009;
// Reads a decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative.
inline LL Read(){
    LL sign = 1, value = 0;
    char ch = getchar();
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + (ch - '0');
    return value * sign;
}
// Binary modular exponentiation: returns base^b mod `mod`, always in [0, mod).
// base may be negative or >= mod; it is reduced first so that the result is a
// canonical residue and base*base cannot overflow 64 bits.
// b must be >= 0 (negative exponents are not supported).
inline LL Pow(LL base,LL b){
    base %= mod;
    if (base < 0) base += mod;  // C++ '%' keeps the sign of the dividend
    LL ret(1);
    while (b) {
        if (b & 1) ret = ret * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return ret;
}
// fac[i] = i! mod `mod`.
LL fac[maxn];
// fav[i] = (i!)^{-1} mod `mod`.
LL fav[maxn];
// Bit-reversal permutation for the current NTT length (filled by Fir).
LL r[maxn];
// Binomial coefficient C(n, m) = n! / (m! (n-m)!) modulo `mod`.
// Requires 0 <= m <= n and the fac/fav tables filled at least up to n.
inline LL Get_c(int n,int m){
    const LL numerator = fac[n];
    const LL denominator_inv = fav[m] * fav[n - m] % mod;
    return numerator * denominator_inv % mod;
}
// Chooses the NTT length for multiplying two polynomials with n coefficients
// each: the smallest power of two `limit` with limit >= 2*n.
// Also fills the global bit-reversal table r[0..limit-1] for that length.
// Returns limit.
inline LL Fir(LL n){
    LL limit(1),len(0);
    while(limit<(n<<1)){
        limit<<=1; ++len;
    }
    // r[0] is always 0. Starting from i = 1 means that when limit == 1
    // (len == 0) no shift by (len - 1) == -1 is evaluated; that shift would be
    // undefined behavior.
    r[0]=0;
    for(int i=1;i<limit;++i)
        r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
    return limit;
}
// In-place iterative radix-2 number-theoretic transform of a[0..n-1].
// type == 1: forward transform; type == -1: inverse transform, including the
// final multiplication by n^{-1}.
// n must equal the length most recently passed through Fir(), because the
// bit-reversal table r[] is read from there.
inline void NTT(LL *a,int n,int type){
    // Reorder the input into bit-reversed index order.
    for(int idx=0;idx<n;++idx)
        if(idx<r[idx]) std::swap(a[idx],a[r[idx]]);
    for(LL half=1;half<n;half<<=1){
        const LL step=half<<1;
        // Root of unity of order `step` (inverted for the inverse transform).
        LL root=Pow(gg,(mod-1)/step);
        if(type==-1) root=Pow(root,mod-2);
        for(LL base=0;base<n;base+=step){
            LL w=1;
            for(LL k=0;k<half;++k){
                const LL u=a[base+k];
                const LL v=a[base+half+k]*w%mod;
                a[base+k]=(u+v)%mod;
                a[base+half+k]=(u-v+mod)%mod;
                w=w*root%mod;
            }
        }
    }
    if(type==-1){
        const LL inv_n=Pow(n,mod-2);
        for(int idx=0;idx<n;++idx) a[idx]=a[idx]*inv_n%mod;
    }
}
// Input sizes: n cells, m colours, S = target occurrence count.
LL n, m, S;
// lim = min(m, n / S); ret = final weighted answer.
LL lim, ret;
// W: input weights; f, g, h: NTT buffers; ans: per-count results.
LL W[maxn];
LL f[maxn];
LL g[maxn];
LL h[maxn];
LL ans[maxn];
int main(){
n=Read(); m=Read(); S=Read();
for(int i=0;i<=m;++i) W[i]=Read();
lim=std::min(m,n/S);
fac[0]=fac[1]=1;
int up(std::max(n,m));
for(int i=2;i<=up;++i)
fac[i]=fac[i-1]*i%mod;
fav[up]=Pow(fac[up],mod-2);
for(int i=up;i>=1;--i)
fav[i-1]=fav[i]*i%mod;
for(int i=0;i<=lim;++i)
f[i]=Get_c(m,i)*fac[n]%mod* Pow(Pow(fac[S],i),mod-2)%mod *fav[n-i*S]%mod *Pow(m-i,n-i*S)%mod *fac[i]%mod;
for(int i=0;i<=(lim>>1);++i)
std::swap(f[i],f[lim-i]);
for(int i=0;i<=lim;++i)
g[i]=(Pow(-1,i)*fav[i]+mod)%mod;
LL limit(Fir(lim+1));
NTT(f,limit,1); NTT(g,limit,1);
for(int i=0;i<limit;++i) h[i]=g[i]*f[i]%mod;
NTT(h,limit,-1);
for(int i=0;i<=lim;++i) ans[i]=h[lim-i]*fav[i]%mod;
for(int i=0;i<=lim;++i) ret=(ret+ans[i]*W[i]%mod)%mod;
printf("%lld\n",ret);
return 0;
}