题解 P4491 [HAOI2018]染色
题意
长度为\(N\)的序列, 每个位置都可以被染成\(M\)种颜色中的某一种.
如果恰好出现了\(S\)次的颜色有\(k\)种, 会产生\(W_k\)的贡献.
求所有可能的染色方案的愉悦度（即上述贡献）之和对 \(1004535809\) 取模的结果.
题解
恰好出现\(S\)次的颜色最多有\(lim=\min(m,\lfloor\frac{n}{s}\rfloor)\)种。
记钦定\(i\)种颜色恰好出现\(S\)次、其余位置任意染色的方案数为\(F[i]\)（即“至少\(i\)种”的带重数计数）,有:
\[F[i]=\tbinom{m}{i}\frac{n!}{(s!)^i(n-is)!}(m-i)^{n-is}
\]
容斥一下得到答案为:
\[ans[i]=\sum_{j=i}^{lim}(-1)^{j-i}\tbinom{j}{i}F[j]
\]
拆开来:
\[ans[i]=\sum_{j=i}^{lim}(-1)^{j-i}\frac{j!}{i!(j-i)!}F[j]
\]
\[ans[i]\times i!=\sum_{j=i}^{lim}\frac{(-1)^{j-i}}{(j-i)!}\times F[j]j!
\]
记\(g[i]=\frac{(-1)^i}{i!},f[i]=i!\times F[i]\)
\[ans[i]\times i!=\sum_{j=i}^{lim}g[j-i]\times f[j]
\]
翻转\(f\)，即令\(f'[k]=f[lim-k]\)，得到：
\[ans[i]\times i!=\sum_{j=i}^{lim}g[j-i]\times f'[lim-j]
\]
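由于\((j-i)+(lim-j)=lim-i\)，右边恰好是\(g\)与\(f'\)卷积后的第\(lim-i\)项，用 NTT 求出卷积即可一次得到所有\(ans[i]\)。最终答案为\(\sum_{i=0}^{lim}W_i\times ans[i]\)。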
代码
#include<bits/stdc++.h>
namespace in{
// Buffered stdin reader: pulls up to 2 MiB at a time with fread and hands out one byte per call.
char buf[1<<21],*p1=buf,*p2=buf;
// Returns the next input byte as a value in [0, 255], or EOF once stdin is exhausted.
inline int getc(){
    if(p1==p2){
        p2=(p1=buf)+std::fread(buf,1,1<<21,stdin);
        if(p1==p2)return EOF;
    }
    // Go through unsigned char so a 0xFF byte is not confused with EOF (-1).
    return static_cast<unsigned char>(*p1++);
}
// Reads one decimal integer into t. Leading non-digit bytes are skipped, and a '-' among
// them makes the value negative. If input ends before any digit, t is left as 0 instead of
// looping forever.
template <typename T>inline void read(T& t){
    t=0;
    bool neg=false;
    int ch=getc();
    while(ch!=EOF&&!std::isdigit(ch)){
        if(ch=='-')neg=true;
        ch=getc();
    }
    while(ch!=EOF&&std::isdigit(ch)){
        t=t*10+(ch-'0');
        ch=getc();
    }
    if(neg)t=-t;
}
// Reads several integers in order, e.g. read(n, m, s).
template <typename T,typename... Args> inline void read(T& t, Args&... args){read(t);read(args...);}
}
namespace out{
// Buffered stdout writer: bytes collect in `buffer` and are emitted with fwrite on flush().
char buffer[1<<21];int p1=-1;const int p2 = (1<<21)-1;
// Writes the p1+1 pending bytes to stdout and empties the buffer.
inline void flush(){std::fwrite(buffer,1,p1+1,stdout),p1=-1;}
// Appends one byte, flushing first when the buffer is full.
inline void putc(const char &x) {if(p1==p2)flush();buffer[++p1]=x;}
// Appends the decimal representation of integer x (no separator).
template <typename T>void write(T x) {
    // 40 slots hold any built-in integer up to 128 bits; the previous 15-byte buffer
    // overflowed for 64-bit values with 16 or more digits.
    char digits[40];
    int len=0;
    if(x<0){
        putc('-');
        // Peel digits while x is still negative so the most negative value is never negated.
        do{digits[len++]=static_cast<char>('0'-x%10);x/=10;}while(x);
    }else{
        do{digits[len++]=static_cast<char>('0'+x%10);x/=10;}while(x);
    }
    while(len>0)putc(digits[--len]);
}
}
// The code below relies on unqualified standard-library names (e.g. min() in main()).
using namespace std;
/// Integer modulo the compile-time prime `mod`, stored as a canonical residue in [0, mod).
/// Division uses Fermat's little theorem, so `mod` must be prime and divisors non-zero.
template<const int mod>
struct modint{
    // Canonical residue in [0, mod).
    int x;
    // Implicit conversion from int; negative or out-of-range values are reduced into [0, mod).
    // Declared with the injected class name: `modint<mod>(...)` as a constructor name is
    // ill-formed since C++20.
    modint(int o=0):x(o){
        if(x<0||x>=mod){x%=mod;if(x<0)x+=mod;}
    }
    modint &operator = (int o){x=modint(o).x;return *this;}
    // Written as x - (mod - o.x) so the intermediate never exceeds INT_MAX, for any int mod.
    modint &operator +=(modint o){x-=mod-o.x;if(x<0)x+=mod;return *this;}
    modint &operator -=(modint o){x-=o.x;if(x<0)x+=mod;return *this;}
    modint &operator *=(modint o){x=static_cast<int>(1ll*x*o.x%mod);return *this;}
    // In-place power by binary exponentiation. Expects b >= 0; note 0^0 == 1.
    modint &operator ^=(int b){
        modint base=*this,res=1;
        for(;b;b>>=1,base*=base)if(b&1)res*=base;
        x=res.x;
        return *this;
    }
    // Multiplies by o^(mod-2), the inverse of o when mod is prime and o != 0.
    modint &operator /=(modint o){return *this*=(o^=mod-2);}
    // int overloads reduce the operand first, so negative / out-of-range ints work too.
    modint &operator +=(int o){return *this+=modint(o);}
    modint &operator -=(int o){return *this-=modint(o);}
    modint &operator *=(int o){return *this*=modint(o);}
    modint &operator /=(int o){return *this/=modint(o);}
    template<class I>friend modint operator +(modint a,I b){return a+=b;}
    template<class I>friend modint operator -(modint a,I b){return a-=b;}
    template<class I>friend modint operator *(modint a,I b){return a*=b;}
    template<class I>friend modint operator /(modint a,I b){return a/=b;}
    friend modint operator ^(modint a,int b){return a^=b;}
    friend bool operator ==(modint a,int b){return a.x==b;}
    friend bool operator !=(modint a,int b){return a.x!=b;}
    bool operator ! () const {return !x;}
    modint operator - () const {return x?mod-x:0;}
    // NOTE: postfix signature, but it increments in place and returns the *new* value by reference.
    modint &operator++(int){return *this+=1;}
};
// Capacity of the global tables (rev, w, fac); fac[] is indexed by n, so n is assumed < N.
constexpr int N=10000005;
// NTT-friendly prime: 1004535809 = 479 * 2^21 + 1, with primitive root 3.
constexpr int mod=1004535809;
// Root used by the forward transform, and its inverse for the backward transform.
const modint<mod> GG=3;
const modint<mod> Ginv=modint<mod>(1)/GG;
struct poly{
vector<modint<mod>>a;
modint<mod>&operator[](int i){return a[i];}
int size(){return a.size();}
void resize(int n){a.resize(n);}
void reverse(){std::reverse(a.begin(),a.end());}
};
// Bit-reversal permutation table: written by init(k), read by ntt() to reorder coefficients.
// Only the first 2^k entries are meaningful for a given transform.
int rev[N];
// Smallest k with 2^k >= n (0 when n <= 1): the log2 of the transform length for n coefficients.
inline int ext(int n){
    int bits=0;
    for(int cap=1;cap<n;cap<<=1)++bits;
    return bits;
}
// Fills rev[0 .. 2^k) with the k-bit bit-reversal permutation used by ntt().
inline void init(int k){
    const int n=1<<k;
    rev[0]=0;
    // Starting at i = 1 means k == 0 (n == 1) never evaluates a shift by k-1 == -1, which is UB.
    for(int i=1;i<n;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
// In-place iterative number-theoretic transform of length 2^k over Z/mod.
// typ > 0: forward transform with root GG; typ < 0: inverse transform with Ginv, scaled by 1/2^k.
// Expects rev[] to hold the permutation produced by init(k), and a to have at least 2^k entries.
inline void ntt(poly&a,int k,int typ){
    const int n=1<<k;
    // Move each coefficient to its bit-reversed slot so the butterflies can run in place.
    for(int i=0;i<n;i++)if(i<rev[i])std::swap(a[i],a[rev[i]]);
    for(int half=1;half<n;half<<=1){
        // Root of unity of order 2*half (inverse root for the backward transform).
        const modint<mod> wn=(typ>0?GG:Ginv)^((mod-1)/(half<<1));
        for(int start=0;start<n;start+=half<<1){
            modint<mod> w=1;
            for(int t=0;t<half;t++,w=w*wn){
                const modint<mod> u=a[start+t],v=w*a[start+t+half];
                a[start+t]=u+v;
                a[start+t+half]=u-v;
            }
        }
    }
    if(typ<0){
        // Scale all n coefficients; the previous bound (n-1) left a[n-1] unscaled.
        const modint<mod> inv=modint<mod>(1)/n;
        for(int i=0;i<n;i++)a[i]*=inv;
    }
}
// Returns the constant polynomial 1.
inline poly one(){
    poly unit;
    unit.resize(1);
    unit[0]=1;
    return unit;
}
// Coefficient-wise sum; the shorter operand is zero-padded to the longer length.
poly operator +(poly a,poly b){
    const int len=std::max(a.size(),b.size());
    a.resize(len);
    b.resize(len);
    for(int i=0;i<len;++i){
        a[i]+=b[i];
    }
    return a;
}
// Coefficient-wise difference a - b; the shorter operand is zero-padded to the longer length.
poly operator -(poly a,poly b){
    const int len=std::max(a.size(),b.size());
    a.resize(len);
    b.resize(len);
    for(int i=0;i<len;++i){
        a[i]-=b[i];
    }
    return a;
}
// Polynomial product via NTT. The result has a.size()+b.size()-1 coefficients,
// or is empty when either operand is empty.
inline poly operator*(poly a,poly b){
    // With an empty operand the length below would be 0 or -1, and resize(-1) would fail.
    if(a.size()==0||b.size()==0)return poly();
    const int n=a.size()+b.size()-1,k=ext(n);
    // Zero-pad both operands to the transform length 2^k >= n so the cyclic product has no wrap-around.
    a.resize(1<<k);
    b.resize(1<<k);
    init(k);
    ntt(a,k,1);
    ntt(b,k,1);
    for(int i=0;i<(1<<k);i++)a[i]*=b[i];
    ntt(a,k,-1);
    a.resize(n);
    return a;
}
// Input read in main(): n, m, s, then the weights w[0..m].
int n,m,s;int w[N];
// fac[i] = i! mod p, filled in main() and used by C(), F() and main().
modint<mod>fac[N];
// Polynomials for the convolution in main(): f is built from F(), g holds (-1)^i / i!.
poly ans,f,g;
// Binomial coefficient C(n, m) mod p from the factorial table; 0 when m < 0 or m > n.
modint<mod>C(int n,int m){
    if(m<0||n<m)return 0;
    // A single modular inverse of the combined denominator instead of two separate ones.
    return fac[n]/(fac[n-m]*fac[m]);
}
// F(i) = C(m,i) * n! / ((s!)^i * (n-i*s)!) * (m-i)^(n-i*s):
// the number of colourings in which i chosen colours each appear exactly s times and the
// remaining cells use any of the other m-i colours.
modint<mod>F(int i){
    // Cells not taken by the i chosen colours.
    const int rest=n-i*s;
    // Pick which i colours are forced to appear exactly s times.
    modint<mod> ways=C(m,i);
    // Each remaining cell takes one of the other m-i colours.
    ways*=modint<mod>(m-i)^rest;
    // Multinomial n! / ((s!)^i * rest!) for arranging the forced colours.
    ways*=fac[n];
    ways/=(fac[s]^i)*fac[rest];
    return ways;
}
// Reads n, m, s and w[0..m] from stdin and prints
// sum_i w[i] * (number of colourings with exactly i colours appearing s times) mod p.
signed main(){
    // Input comes straight from stdin. A leftover debug freopen("color15.in", ...) must not
    // redirect it, or the program fails wherever that file is absent.
    in::read(n,m,s);
    for(int i=0;i<=m;i++)in::read(w[i]);
    // Factorials are indexed up to max(n, m, s) by C() and F(), so only that prefix is built.
    // The table has N entries, so these values are assumed < N.
    const int top=std::max(std::max(n,m),s);
    fac[0]=1;
    for(int i=1;i<=top;i++)fac[i]=fac[i-1]*i;
    // At most min(m, n/s) colours can each appear exactly s times.
    const int lim=std::min(m,n/s);
    f.resize(lim+1);
    g.resize(lim+1);
    // f[j] = j! * F(j), stored reversed so the sum over j >= i becomes an ordinary convolution.
    for(int i=0;i<=lim;i++)f[i]=fac[i]*F(i);
    f.reverse();
    // g[d] = (-1)^d / d!
    for(int i=0;i<=lim;i++)if(i&1)g[i]=-modint<mod>(1)/fac[i];else g[i]=modint<mod>(1)/fac[i];
    // After truncating and reversing, ans[i] = sum_{j>=i} g[j-i] * f[j] (f before reversal),
    // i.e. i! times the exact count for i colours.
    ans=f*g;
    ans.resize(lim+1);
    ans.reverse();
    modint<mod>Ans=0;
    for(int i=0;i<=lim;i++)Ans+=w[i]*ans[i]/fac[i];
    out::write(Ans.x);
    out::flush();
    return 0;
}