洛谷 P4002 - [清华集训2017]生成树计数(多项式)
神题。
考虑将所有连通块缩成一个点,那么所有连好边的生成树在缩点之后一定是一个 \(n\) 个点的生成树。我们记 \(d_i\) 为第 \(i\) 个连通块缩完点之后的度数 \(-1\),那么共有 \(\prod\limits_{i=1}^na_i^{d_i+1}\times\dfrac{(n-2)!}{\prod\limits_{i=1}^nd_i!}\) 个这样的生成树,稍微解释一下这个柿子,因为每个连通块的每条边都有可能是由其中 \(a_i\) 个点中任意一点连出的,因此每个连通块连边的方案为 \(a_i^{d_i+1}\),而由 Prufer 序列那套理论可知 \(d_i\) 即为 \(i\) 在 Prufer 序列中的出现次数,且 \(\sum\limits_{i=1}^nd_i=n-2\),又因为一个 Prufer 序列与一棵生成树形成双射,故这样的 \(n\) 个点的生成树个数为满足 \(i\) 在序列中出现 \(d_i\) 次的长度为 \(n-2\) 的序列个数,即多重组合数 \(\dbinom{n-2}{d_1,d_2,\cdots,d_n}=\dfrac{(n-2)!}{\prod\limits_{i=1}^nd_i!}\),用乘法原理将它们乘起来即可。
因此我们要求的答案即为 \(\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\prod\limits_{i=1}^na_i^{d_i+1}\times\dfrac{(n-2)!}{\prod\limits_{i=1}^nd_i!}\times(\prod\limits_{i=1}^n(d_i+1)^m)(\sum\limits_{i=1}^n(d_i+1)^m)\)。考虑将柿子稍微变个形,将 \(a_i^{d_i+1}\) 拆成 \(a_i^{d_i}\) 与 \(a_i\) 并将 \(a_i\) 提到外面来,再将 \((n-2)!\) 提到外面来,即可得到 \((n-2)!\prod\limits_{i=1}^na_i\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\prod\limits_{i=1}^n\dfrac{a_i^{d_i}}{d_i!}\times(\prod\limits_{i=1}^n(d_i+1)^m)(\sum\limits_{i=1}^n(d_i+1)^m)\),将里面的两个 \(\prod\) 合并,又有 \((n-2)!\prod\limits_{i=1}^na_i\sum\limits_{d_1+d_2+\cdots+d_n=n-2}(\prod\limits_{i=1}^n\dfrac{a_i^{d_i}}{d_i!}(d_i+1)^m)(\sum\limits_{i=1}^n(d_i+1)^m)\),前面的 \((n-2)!\prod\limits_{i=1}^na_i\) 是常数,因此我们只需算出后面的答案后乘上去即可。将里面 \(\sum\) 的括号拆开,移到外面去,可得 \(\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\sum\limits_{i=1}^n(d_i+1)^m(\prod\limits_{j=1}^n\dfrac{a_j^{d_j}}{d_j!}(d_j+1)^m)\),再稍微整理一下即可得到 \(\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\sum\limits_{i=1}^n\dfrac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}\prod\limits_{j\ne i}^n\dfrac{a_j^{d_j}}{d_j!}(d_j+1)^m\)。前面的我们 \(\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\) 非常讨厌,不过长得一脸幂级数的样子,因此考虑记 \(A(x)=\sum\limits_{k\ge 0}\dfrac{x^k}{k!}(k+1)^m\),\(B(x)=\sum\limits_{k\ge 0}\dfrac{x^k}{k!}(k+1)^{2m}\),那么显然 \(\dfrac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}=[x^{d_i}]B(a_ix)\),\(\dfrac{a_j^{d_j}}{d_j!}(d_j+1)^m=[x^{d_j}]A(a_jx)\),故 \(\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\sum\limits_{i=1}^n\dfrac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}\prod\limits_{j\ne i}^n\dfrac{a_j^{d_j}}{d_j!}(d_j+1)^m=\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\sum\limits_{i=1}^n[x^{d_i}]B(a_ix)\prod\limits_{j\ne i}^n[x^{d_j}]A(a_jx)\),而显然它又等于 \([x^{n-2}]\sum\limits_{i=1}^nB(a_ix)\prod\limits_{j\ne i}^nA(a_jx)\),哦吼,\(\sum\limits_{d_1+d_2+\cdots+d_n=n-2}\),这下我们就可以直接从这两个幂级数入手了。
然鹅直接计算 \(\sum\limits_{i=1}^nB(a_ix)\prod\limits_{j\ne i}^nA(a_jx)\) 还是不太容易,考虑按照 P4389 付公主的背包 的套路取 \(\ln\) 再 \(\exp\) 回去,即 \(\sum\limits_{i=1}^nB(a_ix)\exp\sum\limits_{j\ne i}\ln A(a_jx)\),这里 \(j\ne i\) 不太美观,因此考虑直接把 \(j=i\) 的答案先统计进去,前面再除个 \(A(a_ix)\),即 \(\sum\limits_{i=1}^n\dfrac{B(a_ix)}{A(a_ix)}\exp\sum\limits_{j=1}^n\ln A(a_jx)\),这样两个 \(\sum\) 就可以独立计算了,故我们只需算出 \(\sum\limits_{i=1}^n\dfrac{B(a_ix)}{A(a_ix)}\) 和 \(\exp\sum\limits_{j=1}^n\ln A(a_jx)\),然后把它们卷起来即可,不过这里又有一个问题,就是对于所有 \(i\) 都要计算一遍 \(\dfrac{B(a_ix)}{A(a_ix)}\) 及 \(\ln A(a_jx)\),这样复杂度还是平方对数的,显然过不去。不过注意到 \([x^n]\dfrac{B(a_ix)}{A(a_ix)}=a_i^n[x^n]\dfrac{B(x)}{A(x)}\),\([x^n]\ln A(a_jx)=a_j^n[x^n]\ln A(x)\),因此我们只需预处理出 \(\dfrac{B(x)}{A(x)}\) 和 \(\ln A(x)\),那么 \(\sum\limits_{i=1}^n\dfrac{B(a_ix)}{A(a_ix)}\) 的第 \(k\) 项系数就是 \(\dfrac{B(x)}{A(x)}\) 第 \(k\) 项系数的 \(\sum\limits_{i=1}^na_i^k\) 倍,\(\sum\limits_{j=1}^n\ln A(a_jx)\) 也同理。
最后就是怎样对 \(k\in[0,n]\) 求 \(\sum\limits_{i=1}^na_i^k\),其实这个套路在 P4705 玩游戏 里就出现过了,不过这里再讲解一遍。由于 \(a^k=[x^k]\dfrac{1}{1-ax}\),因此这东西的生成函数 \(F(x)=\sum\limits_{i=1}^n\dfrac{1}{1-a_ix}\),注意到 \(\ln'(\dfrac{1}{1-a_ix})=-\dfrac{a_i}{1-a_ix}\),而 \(F(x)=\sum\limits_{i=1}^n\dfrac{1}{1-a_ix}=n-\sum\limits_{i=1}^n-\dfrac{a_i}{1-a_ix}\),故 \(F(x)=n-\sum\limits_{i=1}^n\ln'(\dfrac{1}{1-a_ix})=n+\sum\limits_{i=1}^n\ln'(1-a_ix)\),根据导数的和等于和的倒数可知,\(F(x)\) 又等于 \(n+(\sum\limits_{i=1}^n\ln(1-a_ix))'=n+(\ln\prod\limits_{i=1}^n(1-a_ix))'\),里面的 \(\prod\) 可以分治求出,复杂度 \(n\log^2n\),求出来之后再 \(\ln\) 一遍,\(\text{direv}\) 一遍即可。
时间复杂度 \(n\log^2n\)。
注意特判 \(n=1\),否则 UOJ 上过不去。
const int pr=3;
const int MAXN=3e4;
const int MAXP=1<<16;
const int MOD=998244353;
const int ipr=(MOD+1)/3;
int qpow(int x,int e=MOD-2){
int ret=1;
for(;e;e>>=1,x=1ll*x*x%MOD) if(e&1) ret=1ll*ret*x%MOD;
return ret;
}
int n,m,a[MAXN+5],fac[MAXN+5],ifac[MAXN+5];
void init_fac(int n){
fac[0]=ifac[0]=ifac[1]=1;
for(int i=2;i<=n;i++) ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=1ll*ifac[i]*ifac[i-1]%MOD;
}
int rev[MAXP+5],inv[MAXP+5];
void NTT(vector<int> &a,int len,int type){
int lg=31-__builtin_clz(len);
for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=2;i<=len;i<<=1){
int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
for(int j=0;j<len;j+=i){
int w=1;
for(int k=0;k<(i>>1);k++,w=1ll*w*W%MOD){
int X=a[j+k],Y=1ll*w*a[(i>>1)+j+k]%MOD;
a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
}
}
}
if(type==-1) for(int i=0;i<len;i++) a[i]=1ll*a[i]*inv[len]%MOD;
}
vector<int> conv(vector<int> a,vector<int> b){
int len=1;while(len<a.size()+b.size()) len<<=1;
a.resize(len,0);b.resize(len,0);NTT(a,len,1);NTT(b,len,1);
for(int i=0;i<len;i++) a[i]=1ll*a[i]*b[i]%MOD;
NTT(a,len,-1);return a;
}
vector<int> conv(vector<int> a,vector<int> b,int lenc){
int len=1;while(len<a.size()+b.size()) len<<=1;
a.resize(len,0);b.resize(len,0);NTT(a,len,1);NTT(b,len,1);
for(int i=0;i<len;i++) a[i]=1ll*a[i]*b[i]%MOD;
NTT(a,len,-1);if(a.size()>lenc) a.resize(lenc);
return a;
}
vector<int> getinv(vector<int> a,int n){
vector<int> b(n,0);assert(a[0]!=0);b[0]=qpow(a[0]);
for(int i=2;i<=n;i<<=1){
vector<int> c(a.begin(),a.begin()+i);
vector<int> d(b.begin(),b.begin()+(i>>1));
d=conv(d,d);c=conv(c,d);
for(int j=0;j<i;j++) b[j]=(2*b[j]%MOD-c[j]+MOD)%MOD;
} return b;
}
vector<int> direv(vector<int> a,int n){
vector<int> b(n,0);
for(int i=1;i<n;i++) b[i-1]=1ll*a[i]*i%MOD;
return b;
}
vector<int> inter(vector<int> a,int n){
vector<int> b(n,0);
for(int i=1;i<n;i++) b[i]=1ll*a[i-1]*inv[i]%MOD;
return b;
}
vector<int> getln(vector<int> a,int n){
vector<int> a_=direv(a,n),b=getinv(a,n);
b=conv(a_,b);b=inter(b,n);return b;
}
vector<int> getexp(vector<int> a,int n){
vector<int> b(n,0);b[0]=1;
for(int i=2;i<=n;i<<=1){
vector<int> c(b.begin(),b.begin()+i);
vector<int> d(b.begin(),b.begin()+(i>>1));
c=getln(c,i);for(int j=0;j<i;j++) c[j]=(a[j]-c[j]+MOD)%MOD;
(c[0]+=1)%=MOD;d=conv(d,c);
for(int j=0;j<i;j++) b[j]=d[j];
} return b;
}
int LEN=1;
vector<int> solve(int l,int r){
if(l==r){return vector<int>{1,MOD-a[l]};}
int mid=l+r>>1;return conv(solve(l,mid),solve(mid+1,r),LEN);
}
int main(){
scanf("%d%d",&n,&m);init_fac(n);
if(n==1) return printf("%d\n",m==0),0;
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
while(LEN<=n) LEN<<=1;
for(int i=1;i<=LEN<<1;i++) inv[i]=qpow(i);
vector<int> A(LEN),B(LEN);
for(int i=0;i<=n;i++) A[i]=1ll*ifac[i]*qpow(i+1,m)%MOD;
for(int i=0;i<=n;i++) B[i]=1ll*ifac[i]*qpow(i+1,2*m)%MOD;
vector<int> iA=getinv(A,LEN),lnA=getln(A,LEN);
vector<int> mA=conv(B,iA,LEN);
vector<int> C=solve(1,n);C.resize(LEN,0);
C=getln(C,LEN);C=direv(C,LEN);for(int j=LEN-1;j;j--) C[j]=(MOD-C[j-1])%MOD;
C[0]=n;//for(int i=1;i<=n;i++) printf("%d\n",C[i]);
for(int i=0;i<LEN;i++) lnA[i]=1ll*lnA[i]*C[i]%MOD;
for(int i=0;i<LEN;i++) mA[i]=1ll*mA[i]*C[i]%MOD;
lnA=getexp(lnA,LEN);vector<int> ret=conv(lnA,mA);
int ans=ret[n-2];//printf("%d\n",ans);
for(int i=1;i<=n;i++) ans=1ll*ans*a[i]%MOD;
ans=1ll*ans*fac[n-2]%MOD;
printf("%d\n",ans);
return 0;
}