【xsy2479】counting 生成函数+多项式快速幂
题目大意:在字符集大小为$m$的情况下,有多少种构造长度为$n$的字符串$s$的方案,使得$C(s)=k$。其中$C(s)$表示字符串$s$中出现次数最多的字符的出现次数。
对$998244353$取模,$n,m≤5\times 10^4$
如果你考虑去DP,你就lose了。
令$F(x)$表示满足$C(s)≤x$的方案数。
那么最终的答案显然为$F(k)-F(k-1)$。
这一题有一个非常优美的性质:对于每一种字符,允许的最多出现次数都是$k$。
那么,令$G_k(x)=\sum\limits_{i=0}^{k} \frac{1}{i!}x^i$
则有$F(k)=n![x^n]G_k^m(x)$
证明是显然的
写一个多项式快速幂的板子就过了。
1 #include<bits/stdc++.h> 2 #define M (1<<17) 3 #define L long long 4 #define MOD 998244353 5 #define G 3 6 using namespace std; 7 8 L pow_mod(L x,L k){ 9 L ans=1; 10 while(k){ 11 if(k&1) ans=ans*x%MOD; 12 x=x*x%MOD; k>>=1; 13 } 14 return ans; 15 } 16 17 void change(L a[],int n){ 18 for(int i=0,j=0;i<n-1;i++){ 19 if(i<j) swap(a[i],a[j]); 20 int k=n>>1; 21 while(j>=k) j-=k,k>>=1; 22 j+=k; 23 } 24 } 25 void NTT(L a[],int n,int on){ 26 change(a,n); 27 for(int h=2;h<=n;h<<=1){ 28 L wn=pow_mod(G,(MOD-1)/h); 29 for(int j=0;j<n;j+=h){ 30 L w=1; 31 for(int k=j;k<j+(h>>1);k++){ 32 L u=a[k],t=w*a[k+(h>>1)]%MOD; 33 a[k]=(u+t)%MOD; 34 a[k+(h>>1)]=(u-t+MOD)%MOD; 35 w=w*wn%MOD; 36 } 37 } 38 } 39 if(on==-1){ 40 L inv=pow_mod(n,MOD-2); 41 for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD; 42 reverse(a+1,a+n); 43 } 44 } 45 46 void getinv(L a[],L b[],int n){ 47 if(n==1){b[0]=pow_mod(a[0],MOD-2); return;} 48 static L c[M],d[M]; 49 memset(c,0,n<<4); memset(d,0,n<<4); 50 getinv(a,c,n>>1); 51 for(int i=0;i<n;i++) d[i]=a[i]; 52 NTT(d,n<<1,1); NTT(c,n<<1,1); 53 for(int i=0;i<(n<<1);i++) b[i]=(2*c[i]-d[i]*c[i]%MOD*c[i]%MOD+MOD)%MOD; 54 NTT(b,n<<1,-1); 55 for(int i=0;i<n;i++) b[n+i]=0; 56 } 57 58 void qiudao(L a[],L b[],int n){ 59 memset(b,0,sizeof(b)); 60 for(int i=1;i<n;i++) b[i-1]=i*a[i]%MOD; 61 } 62 void jifen(L a[],L b[],int n){ 63 memset(b,0,sizeof(b)); 64 for(int i=0;i<n;i++) b[i+1]=a[i]*pow_mod(i+1,MOD-2)%MOD; 65 } 66 67 void getln(L a[],L b[],int n){ 68 static L c[M],d[M]; 69 memset(c,0,n<<4); memset(d,0,n<<4); 70 qiudao(a,c,n); getinv(a,d,n); 71 NTT(c,n<<1,1); NTT(d,n<<1,1); 72 for(int i=0;i<(n<<1);i++) c[i]=c[i]*d[i]%MOD; 73 NTT(c,n<<1,-1); 74 jifen(c,b,n); 75 } 76 77 void getexp(L a[],L b[],int n){ 78 if(n==1){b[0]=1; return;} 79 static L lnb[M]; memset(lnb,0,n<<4); 80 getexp(a,b,n>>1); getln(b,lnb,n); 81 for(int i=0;i<n;i++) lnb[i]=(a[i]-lnb[i]+MOD)%MOD,b[i+n]=0; 82 lnb[n]=0; 83 lnb[0]=(lnb[0]+1)%MOD; 84 NTT(lnb,n<<1,1); NTT(b,n<<1,1); 85 for(int i=0;i<(n<<1);i++) b[i]=b[i]*lnb[i]%MOD; 86 NTT(b,n<<1,-1); 87 for(int i=0;i<n;i++) b[i+n]=0; 88 } 89 90 L a[M]={0},b[M]={0}; 91 L fac[M]={0},invfac[M]={0}; 92 int n,k,m; 93 94 L solve(){ 95 memset(a,0,sizeof(a)); 96 memset(b,0,sizeof(b)); 97 int nn=1; while(nn<=n) nn<<=1; 98 for(int i=0;i<=m;i++) a[i]=invfac[i]; 99 L hh=a[0],invhh=pow_mod(hh,MOD-2); 100 for(int i=0;i<nn;i++) a[i]=a[i]*invhh%MOD; 101 getln(a,b,nn); 102 for(int i=0;i<nn;i++) b[i]=b[i]*k%MOD; 103 getexp(b,a,nn); 104 hh=pow_mod(hh,k); 105 for(int i=0;i<nn;i++) a[i]=a[i]*hh%MOD; 106 return a[n]; 107 } 108 109 int main(){ 110 scanf("%d%d%d",&n,&k,&m); 111 fac[0]=1; for(int i=1;i<M;i++) fac[i]=fac[i-1]*i%MOD; 112 invfac[M-1]=pow_mod(fac[M-1],MOD-2); 113 for(int i=M-2;~i;i--) invfac[i]=invfac[i+1]*(i+1)%MOD; 114 L res1=solve(); 115 m--; 116 L res2=solve(); 117 cout<<(res1-res2+MOD)*fac[n]%MOD<<endl; 118 }