洛谷P5206 [WC2019] 数树(生成函数+容斥+矩阵树)
题面
前置芝士
矩阵树,基本容斥原理,生成函数,多项式\(\exp\)
题解
我也想哭了……orz rqy,orz shadowice
我们设\(T1,T2\)为两棵树,并定义一个权值函数\(w(T1,T2)=y^{n-|T1\cap T2|}\),其中\(|T1\cap T2|\)为两棵树共同拥有的边的数目
显然,\(w(T1,T2)\)就是两棵树在该情况下的方案个数,因为\(T1\cap T2\)后的图中每个连通块只能用同一种颜色,而\(n-|T1\cap T2|\)就是连通块个数
子任务\(0\)就是给出\(T1,T2\),求\(w(T1,T2)\)
子任务\(1\)就是给定\(T1\),对所有的\(T2\)求和
子任务\(2\)就是对所有的\(T1,T2\)求和
子任务\(0\)
就是问有多少条边重合,直接暴力跑一下就行了
子任务\(1\)
\(y^{n-|T1\cap T2|}\)太麻烦了,把它化成\(y^{-|T1\cap T2|}\),最后再把答案乘上一个\(y^n\)就行了
我们现在已经知道了\(T1\),我们假设以下设计的所有集合都是\(T1\)的子集。设一个\(F(S)\)表示\(T1\cap T2=S\)的所有\(w(T1,T2)\)的权值之和,那么答案就是\(\sum_{S} F(S)\)。显然\(F(S)=G(S)y^{-|S|}\),其中\(G(S)\)表示\(T1\cap T2=S\)的\(T2\)的个数
然后我们再设一个\(A(S)=B(S)y^{-|S|}\),其中\(B(S)\)表示含有\(S\)这个边集的树的个数
首先我们很容易写出一个容斥式子
我们在两边同乘上\(y\),可以变成
然后我们再来看一看答案
然后令\(p={1\over y}-1\),那么式子就变成了
然而它的复杂度仍然是指数级的还是没用,还得继续
考虑\(C(T)\),表示至少含有\(T\)这个边集的树的个数,设这个边集有\(k\)个连通块,第\(i\)个连通块中有\(a_i\)个点,那么可以知道
那么我们来手屠矩阵吧
先来考虑一下\(C(T)\)该怎么计算。我们把所有连通块中的点缩到一起,对于两个连通块\(i,j\),在两个点之间连接\(a_ia_j\)条重边。那么这个新的图的矩阵树求出的答案就是\(C(T)\)了
那么这样建出来图的基尔霍夫矩阵应该长这样
然后我们删去一行一列之后会变成这样
接下来我们将第 \(i\) 行除去 \(a_{i}\) 这样最终的行列式需要乘上一个 \(\prod_{i=1}^{k-1}a_{i}\) ,矩阵会变成这样
然后我们将第2列到第k-1列加到第1列上,会得到
接下来我们用第1行去减其他的行,会得到
这样矩阵就被我们削成了一个上三角阵,它的行列式是 \(n^{k-2}a_{k}\) 乘上 \(\prod_{i=1}^{k-1}a_{i}\) 就是
这样我们就证明了我们的式子是正确的
那么式子可以继续展开
令\(k={n\over p}\),问题可以转化成如下形式:一个连通块权值为\(k\)乘上这个连通块的大小,一个边集\(T\)的权值等于其中所有连通块的权值之积,求所有边集的权值之和
设\(f(i,j)\)表示考虑到\(i\)这棵子树,其中\(i\)所在的连通块大小为\(j\),总的答案是多少。不过因为这里的权值并不包含\(i\)所在的连通块,所以最终的答案还要算上\(i\)的连通块的贡献,那么答案就是
然而这样的复杂度是\(O(n^2)\),还是要\(T\),得继续优化
我们设一个生成函数
发现这个生成函数有一个特点,就是\(f'_u(1)\)恰好就是\(\sum_{1=1}^n f(u,j)j\),其中前者表示对\(f_u(x)\)求导之后用\(1\)代入\(x\)。那么最后的答案就是\(kf'_1(1)\)
考虑转移,为
设\(g_u=f'_u(1)\)
如果设\(h_u=f_u(1)\),那么转移式子可以写成
边界条件为\(h_u=1,g_u=k\)
那么就可以\(O(n)\)树形\(dp\)了
最后把输出\({g_1p^ny^n\over n^2}\)
子任务\(2\)
我们发现其实子任务\(1\)里的容斥依然可以用在这里,唯一的区别就是我们要把\(C(T)\)改成\(C^2(T)\),那么这里\(C^2(T)\)就表示交集至少为\(T\)的二元组\(T1,T2\)的对数。不过注意这里要枚举的是所有边的子集\(T\)
那么答案就变成了
然后设\(k={n^2\over p}\),问题就变成了:一个大小为\(i\)的树的权值为\(ki^2\),一个森林的权值为所有树得权值之积,求所有森林的权值之和
我们先考虑树的权值,大小为\(i\)的无根树总共有\(i^{i-2}\)个,每一个贡献为\(ki^2\),那么大小为\(i\)的树的权值总和就是\(ki^i\),它的指数型生成函数就是
对于每一个森林,都是由若干棵树构成的,且这些树之间是无序的。那么根据指数型生成函数的意义,我们可以得知森林权值的指数型生成函数为
那么只要做一次多项式\(exp\)就行了,然后取第\(n\)项,把答案乘上\(\frac{p^nbas^{n}n!}{n^4}\)
以上是题解,以下是吐槽
这题知识点还真多……多项式板子打错之后只能干瞪眼果然不是瞎吹的……用了一天看懂题解,用了一天把代码调出来……这才是真正的数数题么……
//minamoto
#include<bits/stdc++.h>
#define R register
#define fp(i,a,b) for(R int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(R int i=a,I=b-1;i>I;--i)
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
inline char getc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
int read(){
R int res,f=1;R char ch;
while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1);
for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0');
return res*f;
}
const int N=1e5+5,P=998244353,Gi=332748118;
inline int add(R int x,R int y){return x+y>=P?x+y-P:x+y;}
inline int dec(R int x,R int y){return x-y<0?x-y+P:x-y;}
inline int mul(R int x,R int y){return 1ll*x*y-1ll*x*y/P*P;}
int ksm(R int x,R int y){
R int res=1;
for(;y;y>>=1,x=mul(x,x))if(y&1)res=mul(res,x);
return res;
}
int n,bas,op;
namespace solver0{
struct node{
int u,v;
node(R int uu=0,R int vv=0){u=min(uu,vv),v=max(uu,vv);}
inline bool operator <(const node &b)const{return u==b.u?v<b.v:u<b.u;}
inline bool operator ==(const node &b)const{return u==b.u&&v==b.v;}
}e[N<<1];int tot,u,v,res;
void MAIN(){
fp(i,1,n-1)u=read(),v=read(),e[++tot]=node(u,v);
fp(i,1,n-1)u=read(),v=read(),e[++tot]=node(u,v);
sort(e+1,e+1+tot),tot=unique(e+1,e+1+tot)-e-1;
res=(n-1<<1)-tot,res=n-res,printf("%d\n",ksm(bas,res));
}
}
namespace solver1{
struct eg{int v,nx;}e[N<<1];int head[N],tot;
inline void add_edge(R int u,R int v){e[++tot]={v,head[u]},head[u]=tot;}
int f[N],g[N],u,v,res,p,k;
void dfs(int u,int fa){
f[u]=1,g[u]=k;
go(u)if(v!=fa){
dfs(v,u),p=add(f[v],g[v]);
g[u]=(1ll*g[u]*p+1ll*g[v]*f[u])%P;
f[u]=mul(f[u],p);
}
}
void MAIN(){
if(bas==1)return printf("%d\n",ksm(n,n-2)),void();
res=dec(ksm(bas,P-2),1),k=mul(n,ksm(res,P-2));
fp(i,1,n-1)u=read(),v=read(),add_edge(u,v),add_edge(v,u);
dfs(1,0);
printf("%d\n",1ll*g[1]*ksm(mul(res,bas),n)%P*ksm(mul(n,n),P-2)%P);
}
}
namespace solver2{
const int N=5e5+5;
int r[N],O[N],inv[N],fac[N],ifac[N],f[N],g[N],l,lim,res;
void init(R int len){
lim=1,l=0;while(lim<len)lim<<=1,++l;
fp(i,0,lim-1)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
void NTT(int *A,int ty,int len=0){
fp(i,0,lim-1)if(i<r[i])swap(A[i],A[r[i]]);
for(R int mid=1;mid<lim;mid<<=1){
int I=(mid<<1),Wn=ksm(ty==1?3:Gi,(P-1)/I);O[0]=1;
fp(i,1,mid-1)O[i]=mul(O[i-1],Wn);
for(R int j=0;j<lim;j+=I)fp(k,0,mid-1){
int x=A[j+k],y=mul(O[k],A[j+k+mid]);
A[j+k]=add(x,y),A[j+k+mid]=dec(x,y);
}
}
if(ty==-1)for(R int i=0,inv=ksm(lim,P-2);i<lim;++i)A[i]=mul(A[i],inv);
}
void Inv(int *a,int *b,int len){
if(len==1)return b[0]=ksm(a[0],P-2),void();
Inv(a,b,len>>1);
static int A[N],B[N];init(len<<1);
fp(i,0,len-1)A[i]=a[i],B[i]=b[i];
fp(i,len,lim-1)A[i]=B[i]=0;
NTT(A,1),NTT(B,1);
fp(i,0,lim-1)A[i]=mul(A[i],mul(B[i],B[i]));
NTT(A,-1);
fp(i,0,len-1)b[i]=dec(add(b[i],b[i]),A[i]);
fp(i,len,lim-1)b[i]=0;
}
void Ln(int *a,int *b,int len){
static int A[N],B[N];
fp(i,1,len-1)A[i-1]=mul(a[i],i);A[len-1]=0;
Inv(a,B,len);init(len<<1);
fp(i,len,lim-1)A[i]=B[i]=0;
NTT(A,1),NTT(B,1);
fp(i,0,lim-1)A[i]=mul(A[i],B[i]);
NTT(A,-1);
fp(i,1,len-1)b[i]=mul(A[i-1],inv[i]);b[0]=0;
fp(i,len,lim-1)b[i]=0;
}
void Exp(int *a,int *b,int len){
if(len==1)return b[0]=1,void();
Exp(a,b,len>>1);
static int A[N];
Ln(b,A,len);init(len<<1);
A[0]=dec(a[0]+1,A[0]);
fp(i,1,len-1)A[i]=dec(a[i],A[i]);
fp(i,len,lim-1)A[i]=b[i]=0;
NTT(A,1),NTT(b,1);
fp(i,0,lim-1)b[i]=mul(b[i],A[i]);
NTT(b,-1);
fp(i,len,lim-1)b[i]=0;
}
void Pre(int len){
inv[0]=inv[1]=1;fp(i,2,len)inv[i]=mul(P-P/i,inv[P%i]);
fac[0]=fac[1]=ifac[0]=1;fp(i,2,len)fac[i]=mul(fac[i-1],i);
ifac[len]=ksm(fac[len],P-2);fd(i,len-1,1)ifac[i]=mul(ifac[i+1],i+1);
}
void MAIN(){
if(bas==1)return printf("%d\n",mul(ksm(n,n-2),ksm(n,n-2))),void();
int len=1;while(len<=n)len<<=1;Pre(len);
int p=dec(ksm(bas,P-2),1),k=1ll*n*n%P*ksm(p,P-2)%P;
fp(i,1,len-1)f[i]=1ll*k*ksm(i,i)%P*ifac[i]%P;
Exp(f,g,len);
res=1ll*g[n]*fac[n]%P*ksm(p,n)%P*ksm(bas,n)%P*ksm(1ll*n*n%P*n%P*n%P,P-2)%P;
printf("%d\n",res);
}
}
int main(){
// freopen("testdata.in","r",stdin);
n=read(),bas=read(),op=read();
switch(op){
case 0:solver0::MAIN();break;
case 1:solver1::MAIN();break;
case 2:solver2::MAIN();break;
}
return 0;
}