UOJ#449. 【集训队作业2018】喂鸽子 min-max容斥,FFT
原文链接www.cnblogs.com/zhouzhendong/p/UOJ449.html
题解
设 f(i) 表示给 i 只鸽子喂食使得至少一只鸽子被喂饱的期望次数,先 min-max容斥 一下。($\frac ni$ 表示期望每 $\frac ni$ 步喂这 i 只鸽子一次)
$$ans = \sum_{i=1}^n (-1)^{i+1}\binom ni \frac ni \cdot f(i)$$
考虑如何求 f(i) 。假设我们喂饱的是第一只鸽子,那么假设我们喂了其他鸽子 j 次,那么就可以得到以下式子:
$$f(i) = \sum_{j=0}^{\infty} (j+k) \binom {j+k-1}{k-1} \cdot \left ( g^{i-1} \right ) ^{(j)} (0)\cdot \frac 1{i^{j+k}}$$
(注: $h^{(a)}(x)$ 表示函数 $h(x)$ 的 a 阶导数,$h^{(a)}(0)$ 表示指数生成函数 $h$ 的第 a 项系数)
其中 $\left(g^{i-1}\right)^{(j)}(0)$ 表示给 i-1 只鸽子喂食,每只喂的次数不超过 k-1 次,总共喂了 j 次的方案数。由于还有一只要强制喂到 k 次,所以要乘上 $\binom{j+k-1}{k-1}$ ,这种情况下喂了 $j+k$ 次鸽子,所以要乘上 $j+k$。
那么这个 g(x) 是什么东西?
对于一只鸽子,可以喂 $0,1,2,\cdots, k-1$ 次,搞一个指数生成函数就好了。
$$ g(x) = \sum_{i=0}^{k-1} \frac{ x^i} {i!}$$
时间复杂度 $O(n^2k \log (nk))$ 。
好像还有一个 $O(n^2k)$ 的神仙做法,先坑着。
代码
#pragma GCC optimize("Ofast","inline") #include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) #define For(i,a,b) for (int i=a;i<=b;i++) #define Fod(i,b,a) for (int i=b;i>=a;i--) #define pb(x) push_back(x) #define mp(x,y) make_pair(x,y) #define fi first #define se second #define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I') #define outval(x) printf(#x" = %d\n",x) #define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("") #define outtag(x) puts("----------"#x"----------") #define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R);\ For(_v2,L,R)printf("%d ",a[_v2]);puts(""); using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef vector <int> vi; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=55,K=1005,S=1<<16,mod=998244353; void Add(int &x,int y){ if ((x+=y)>=mod) x-=mod; } void Del(int &x,int y){ if ((x-=y)<0) x+=mod; } int del(int x,int y){ return x-y<0?x-y+mod:x-y; } int Pow(int x,int y){ int ans=1; for (;y;y>>=1,x=(LL)x*x%mod) if (y&1) ans=(LL)ans*x%mod; return ans; } int Fac[S],Inv[S]; void prework(){ int n=S-1; for (int i=Fac[0]=1;i<=n;i++) Fac[i]=(LL)Fac[i-1]*i%mod; Inv[n]=Pow(Fac[n],mod-2); Fod(i,n,1) Inv[i-1]=(LL)Inv[i]*i%mod; } int C(int n,int m){ if (m<0||m>n) return 0; return (LL)Fac[n]*Inv[m]%mod*Inv[n-m]%mod; } int n,k; int m,d,invm; int f[N]; int R[S],w[S]; int a[S],b[S],c[S]; void FFT(int *a,int n){ For(i,0,m-1) if (i<R[i]) swap(a[i],a[R[i]]); for (int t=n>>1,d=1;d<n;d<<=1,t>>=1) for (int i=0;i<n;i+=d<<1) for (int j=0;j<d;j++){ int tmp=(LL)w[t*j]*a[i+j+d]%mod; a[i+j+d]=del(a[i+j],tmp); Add(a[i+j],tmp); } } int main(){ prework(); n=read(),k=read(); for (m=1,d=0;m<n*k;m<<=1,d++); invm=Pow(m,mod-2); For(i,0,m-1) R[i]=(R[i>>1]>>1)|((i&1)<<(d-1)); w[0]=1,w[1]=Pow(3,(mod-1)/m); For(i,2,m-1) w[i]=(LL)w[i-1]*w[1]%mod; clr(a); For(i,0,k-1) a[i]=Inv[i]; FFT(a,m); For(i,0,m-1) b[i]=1; For(x,1,n){ For(i,0,m-1) c[i]=b[i]; reverse(w+1,w+m); FFT(c,m); reverse(w+1,w+m); For(i,0,m-1) c[i]=(LL)c[i]*invm%mod*Fac[i]%mod; f[x]=0; For(i,0,m-1) if (c[i]) Add(f[x],(LL)C(i+k-1,k-1)*c[i]%mod*(i+k)%mod*Pow(x,mod-i-k)%mod); For(i,0,m-1) b[i]=(LL)b[i]*a[i]%mod; } int ans=0; For(i,1,n){ int tmp=(LL)C(n,i)*n%mod*Pow(i,mod-2)%mod*f[i]%mod; if (i&1) Add(ans,tmp); else Del(ans,tmp); } cout<<ans<<endl; return 0; }