题解 uoj390 百鸽笼
Description
Solution
首先把选择还有空的鸽笼改为可以选择已满的鸽笼 , 但不产生任何效果 .
把最后一个空着的位置改为最后一个被填的位置 .
考虑计算第 \(i\) 列的答案 .
在上面的转化下 , 第 \(i\) 列的答案等价于所有满足下列条件的选择序列的概率和
- 出现了 \(a_i\) 个 \(i\)
- 对于 \(j\neq i\) , 出现了至少 \(a_j\) 个 \(j\)
- 序列末尾为 \(i\) .
如果存在一个这样的长度为 \(L\) 的序列 , 那么该序列对答案的贡献就是 \(\displaystyle(\frac{1}{n})^L\) .
考虑答案的 EGF.
设最后 \(j\) 在序列中出现了 \(t\) 次 , 那么它对 EGF 的贡献为 \(\displaystyle G(t)=\frac{x^t}{n^tt!}\),\(t!\) 为可重集排列 .
那么每个 \(j\neq i\) 在 EGF 中的贡献为 \(\displaystyle\sum\limits_{s=a_j}^{+\infty}G(s)\)
那么答案的 EGF 即为 \(\displaystyle\frac{G(a_i-1)}{n}\prod\limits_{j\neq i}\sum\limits_{s=a_j}^{+\infty}G(s)\)
发现 \(\displaystyle\sum\limits_{s=0}^{+\infty}G(s)=e^{\frac{x}{n}}\), 所以设 \(\displaystyle S(t)=\sum\limits_{s=0}^tG(s)\)
那么答案式即为 \(\displaystyle\frac{G(a_i-1)}{n}\prod\limits_{j\neq i}(e^{\frac{x}{n}}-S(a_j-1))\)
设 \(m=\max a\)
可以使用分治乘法暴力卷积 \(O(n^3m^2)\) 求出对于全部的 \(i\) 的上式 , 也可以 ntt \(O(n^3m\log(nm))\) , 但实测没有前一种快
那么考虑最后 \(\displaystyle e^{\frac{mx}{n}}x^t\) 的系数对答案的贡献系数 .
因为 \(\displaystyle e^{\frac{mx}{n}}=\sum\limits_{s\geq 0}\frac{m^sx^s}{n^ss!}\)
所以 \(\displaystyle e^{\frac{mx}{n}}x^t=\sum\limits_{s\geq 0}\frac{m^sx^{s+t}}{n^ss!}\)
所以贡献系数即为 \(\displaystyle\sum\limits_{s\geq 0}\frac{m^s(s+t)!}{n^ss!}\)
上式等于 \(\displaystyle(\frac{n}{n-m})^{t+1}t!\)
可以将 \(\displaystyle(\frac{1}{1-x})^{t+1}t!\) 在原点泰勒展开来证明 .
然后把所有的系数乘上对应的贡献系数就是答案 .
时间复杂度 \(O(n^3m^2)\)
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cmath>
typedef long long ll;
using namespace std;
int read()
{
int ret=0;char c=getchar();
while(c>'9'||c<'0')c=getchar();
while(c>='0'&&c<='9')ret=(ret<<3)+(ret<<1)+(c^48),c=getchar();
return ret;
}
const int maxn=35;
const int maxm=1005;
const int mod=998244353;
int n,a[maxn],sum;
int ans[maxn];
int qpow(int a,int b){int ret=1;for(;b;b>>=1,a=(ll)a*a%mod)if(b&1)ret=(ll)ret*a%mod;return ret;}
int R[1<<10],W[1<<10];
struct poly
{
vector<int>v;
int& operator [](int i){return v[i];}
void ntt(int L,int typ)
{
int n=pow(2,L);
for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
W[0]=1;W[1]=qpow(3,(mod-1)/n);if(typ==-1)W[1]=qpow(W[1],mod-2);
for(int i=2;i<n;i++)W[i]=(ll)W[i-1]*W[1]%mod;
if(v.size()<n)v.resize(n,0);
for(int i=0;i<n;i++)if(R[i]>i)swap(v[i],v[R[i]]);
for(int t=n>>1,d=1;d<n;d<<=1,t>>=1)
for(int i=0;i<n;i+=(d<<1))
for(int j=0;j<d;j++)
{
int tmp=(ll)W[t*j]*v[i+j+d]%mod;
v[i+j+d]=(v[i+j]-tmp+mod)%mod;
v[i+j]=(v[i+j]+tmp)%mod;
}
if(typ==-1){int inv=qpow(n,mod-2);for(int i=0;i<n;i++)v[i]=(ll)v[i]*inv%mod;}
}
void operator *=(poly &x)
{
int L=ceil(log2(v.size()+x.v.size())),n=pow(2,L);
ntt(L,1);x.ntt(L,1);
for(int i=0;i<n;i++)v[i]=(ll)v[i]*x[i]%mod;
ntt(L,-1);x.ntt(L,-1);
while(v.size()>1&&v.back()==0)v.pop_back();
while(x.v.size()>1&&x.v.back()==0)x.v.pop_back();
}
void operator +=(poly &x)
{
if(v.size()<x.v.size())v.resize(x.v.size(),0);
for(int i=0;i<x.v.size();i++)v[i]=(v[i]+x[i])%mod;
}
};
int fac[maxm],ifac[maxm],in[maxm],pown[maxm],ipoww[maxn][maxm];
poly s[maxn];
void prework()
{
fac[0]=1;for(int i=1;i<=sum;i++)fac[i]=(ll)fac[i-1]*i%mod;
ifac[0]=ifac[1]=1;for(int i=2;i<=sum;i++)ifac[i]=(ll)(mod-mod/i)*ifac[mod%i]%mod;
for(int i=2;i<=sum;i++)ifac[i]=(ll)ifac[i-1]*ifac[i]%mod;
in[0]=1;in[1]=qpow(n,mod-2);for(int i=2;i<=sum;i++)in[i]=(ll)in[i-1]*in[1]%mod;
pown[0]=1;for(int i=1;i<=sum;i++)pown[i]=(ll)pown[i-1]*n%mod;
for(int j=0;j<=sum;j++)ipoww[1][j]=1;
for(int i=2;i<=n;i++)
{
ipoww[i][0]=1;ipoww[i][1]=(ll)(mod-mod/i)*ipoww[mod%i][1]%mod;
for(int j=2;j<=sum;j++)ipoww[i][j]=(ll)ipoww[i][j-1]*ipoww[i][1]%mod;
for(int j=0;j<=sum;j++)assert((ll)ipoww[i][j]*qpow(i,j)%mod==1);
}
for(int i=1;i<=n;i++)
{
s[i].v.resize(a[i]);
for(int j=0;j<a[i];j++)s[i][j]=mod-(ll)in[j]*ifac[j]%mod;
}
}
poly ret[maxn];
void solve(int l,int r)
{
if(l==r)
{
poly g;g.v.resize(a[l]);g.v[a[l]-1]=(ll)in[a[l]]*ifac[a[l]-1]%mod;
for(int i=0;i<=n-1;i++)ret[i]*=g;
for(int i=0;i<=n-1;i++)for(int j=0;j<ret[i].v.size();j++)ans[l]=(ans[l]+(ll)ret[i][j]*pown[j+1]%mod*ipoww[n-i][j+1]%mod*fac[j]%mod)%mod;
return;
}
int mid=(l+r)>>1;
poly tmp[maxn];
for(int i=0;i<=n-1;i++)tmp[i]=ret[i];
for(int j=mid+1;j<=r;j++)for(int k=n-1;k>=0;k--){ret[k]*=s[j];if(k)ret[k]+=ret[k-1];}
solve(l,mid);
for(int i=0;i<=n-1;i++)ret[i]=tmp[i];
for(int j=l;j<=mid;j++)for(int k=n-1;k>=0;k--){ret[k]*=s[j];if(k)ret[k]+=ret[k-1];}
solve(mid+1,r);
}
int main()
{
generate_n(a+1,n=read(),read);
for(int i=1;i<=n;i++)sum+=a[i];
prework();
ret[0].v.resize(1,1);
solve(1,n);
for(int i=1;i<=n;i++)printf("%d ",ans[i]);
return 0;
}