LGP5206题解
\(opt=0\)
该部分过于简单不再阐述。
\(opt=1\)
设第一棵树的边集为 \(T1\),第二棵树的边集为 \(T2\)。答案是:
\[\sum_{T2}y^{n-|T1\cap T2|}
\]
\[\sum_{S}y^{n-|S|}\sum_{|T1\cap T2|=S}
\]
\[\sum_{|T1\cap T2|=S}y^{n-|S|}
\]
\[\sum_{S=|T1\cap T2|}\sum_{Q\subseteq S}\sum_{P\subseteq Q}(-1)^{|Q|-|P|}y^{n-|P|}
\]
\[\sum_{Q\subseteq T1,Q\subseteq T2}\sum_{P\subseteq Q}(-1)^{|Q|-|P|}y^{n-|P|}
\]
\[\sum_{Q\subseteq T1,P\subseteq T2}y^{n-|Q|}\sum_{P\subseteq Q}(-y)^{|Q|-|P|}
\]
\[\sum_{Q\subseteq T1,Q\subseteq T2}y^{n-|Q|}\sum_{i=0}^{|Q|}\binom{|Q|}{i}(-y)^i
\]
\[\sum_{Q\subseteq T1,Q\subseteq T2}y^{n-|Q|}(1-y)^{|Q|}
\]
\[y^n\sum_{Q\subseteq T1,Q\subseteq T2}(\frac{1-y}{y})^{|Q|}
\]
设 \(g(S)\) 为已经选择了 \(S\) 这个边集后,在 \(T2\) 上将联通块串起来的方案数。
\[y^n\sum_{Q\subseteq T1}g(Q)(\frac{1-y}{y})^{|Q|}
\]
\[y^n\sum_{S\subseteq T1}g(S)(\frac{1-y}{y})^{|S|}
\]
\[y^n\sum_{S\subseteq T1}\prod_{i=1}^{k}s_in^{k-2}(\frac{1-y}{y})^{|S|}
\]
\[(1-y)^{n}\sum_{S\subseteq T1}n^{k-2}(\frac{y}{1-y})^{k}\prod_{i=1}^{k}s_i
\]
\[\frac{(1-y)^n}{n^2}\sum_{S\subseteq T1}\prod_{i=1}^{k}\frac{ny}{1-y}s_i
\]
设 \(dp[u][k]\) 表示 \(u\) 子树中连通块大小为 \(k\) 时的权值之和。在最开始时加入自身,然后考虑加入一个儿子的时候该怎么做:
\[dp[u][k]=\sum_{i+j=k}dp[u][i]dp[v][j]
\]
\[dp[u][0]=\frac{ny}{1-y}\times\sum dp[u][i]i,dp[u][i]=dp[u][i-1]
\]
考虑多项式,设 \(F_u(x)=\sum dp[u][i]x^i\)
\[F_u(x)=F_u(x)(F_v(x)+\frac{ny}{1-y}\times F_v'(1))
\]
发现答案就是 \(k\times F_1'(1)\)。于是考虑对每个节点维护 \(F_u(1)\) 和 \(F_u'(1)\):
\[F_u(1)=F_u(1)(F_v(1)+\frac{ny}{1-y}\times F_v'(1))
\]
\[F_u'(1)=F_u(1)F_v'(1)+F_u'(1)(F_v(1)+\frac{ny}{1-y}\times F_v'(1))
\]
\(opt=2\)
\[y^n\sum_{Q\subseteq T1,Q\subseteq T2}(\frac{1-y}{y})^{|Q|}
\]
设 \(g(S)\) 为选择了 \(S\) 这个边集之后,在 \(T1\) 和 \(T2\) 上将联通块串起来的方案数
\[y^n\sum_S g(S)(\frac{1-y}{y})^{|S|}
\]
\[y^n\sum_S\prod_{i=1}^{k}s_i^2n^{2k-4}(\frac{1-y}{y})^{n-|S|}
\]
\[\frac{(1-y)^n}{n^4}\sum_S\prod_{i=1}^{k}\frac{n^2y}{1-y}s_i^2
\]
相当于一个有 \(k\) 的联通块的权值是 \(\frac{n^2y}{1-y}k\),求将树划分成若干个联通块的权值之积之和。
显然这是一个 \(\rm SET\) 构造,但是要注意一下给联通块的内部填上编号,需要乘上一个生成树的数量,所以直接令 \(F(x)=\sum\frac{n^2y}{1-y}i^ix^i\),计算 \(\exp(F(x))\) 即可。
怎么感觉二问比一问简单
#include<cstdio>
#include<cctype>
#include<map>
#define IMP(lim,act) for(int qwq=(lim),i=0;i^qwq;++i)act
const int M=1<<18|5,mod=998244353;
int inv[M<<1],fac[M<<1],ifac[M<<1],buf[M<<1],*w[20];int n,y;
inline int Getlen(const int&n){
int len(0);while((1<<len)<n)++len;return len;
}
inline int Add(const int&a,const int&b){
return a+b>=mod?a+b-mod:a+b;
}
inline int Del(const int&a,const int&b){
return b>a?a-b+mod:a-b;
}
inline void swap(int&a,int&b){
int c=a;a=b;b=c;
}
inline int pow(int a,int b=mod-2){
int ans(1);for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*ans*a%mod;return ans;
}
inline void init(const int&n){
const int&m=Getlen(n);int*now=buf;w[m]=now;now+=1<<m;
w[m][0]=1;w[m][1]=pow(3,mod-1>>m+1);for(int i=2;i^1<<m;++i)w[m][i]=1ll*w[m][i-1]*w[m][1]%mod;
for(int k=m-1;k>=0&&(w[k]=now,now+=1<<k);--k)IMP(1<<k,w[k][i]=w[k+1][i<<1]);
fac[0]=fac[1]=ifac[0]=ifac[1]=inv[1]=1;for(int i=2;i^1<<m;++i)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
for(int i=1;i<=n;++i)fac[i]=1ll*fac[i-1]*i%mod,ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
}
inline void DFT(int*f,const int&M){
const int&n=1<<M;
for(int len=n>>1,d=M-1;d>=0;--d,len>>=1)for(int k=0;k^n;k+=len<<1){
int*W=w[d],*L=f+(k),*R=f+(k|len),x,y;IMP(len,(x=*L,y=*R)),*L++=Add(x,y),*R++=1ll**W++*Del(x,y)%mod;
}
}
inline void IDFT(int*f,const int&M){
const int&n=1<<M;
for(int len=1,d=0;d^M;++d,len<<=1)for(int k=0;k^n;k+=len<<1){
int*W=w[d],*L=f+(k),*R=f+(k|len),x,y;IMP(len,(x=*L,y=1ll**W++**R%mod)),*L++=Add(x,y),*R++=Del(x,y);
}
const int&k=pow(n);IMP(n,f[i]=1ll*f[i]*k%mod);for(int i=1;(i<<1)<n;++i)swap(f[i],f[n-i]);
}
inline void Inv(int*f,const int&n){
static int b1[M],b2[M],b3[M];const int&m=Getlen(n);b1[0]=pow(f[0]);
for(int len=1;len<=m;++len){
IMP(1<<len-1,b2[i]=2ll*b1[i]%mod);IMP(1<<len,b3[i]=f[i]);
DFT(b1,len+1);DFT(b3,len+1);IMP(1<<len+1,b1[i]=1ll*b1[i]*b1[i]%mod*b3[i]%mod);IDFT(b1,len+1);
IMP(1<<len,b1[i]=Del(b2[i],b1[i])),b3[i]=b3[1<<len|i]=b1[1<<len|i]=0;
}
IMP(n,f[i]=b1[i]);IMP(1<<m,b1[i]=b2[i]=0);
}
inline void Der(int*f,const int&n){
IMP(n-1,f[i]=1ll*f[i+1]*(i+1)%mod);f[n-1]=0;
}
inline void Int(int*f,const int&n){
for(int i=n-1;i>=0;--i)f[i+1]=1ll*f[i]*inv[i+1]%mod;f[0]=0;
}
inline void Ln(int*f,const int&n){
static int g[M];const int&len=Getlen(n+n-2);IMP(n,g[i]=f[i]);Der(f,n);Inv(g,n);
DFT(f,len);DFT(g,len);IMP(1<<len,g[i]=1ll*f[i]*g[i]%mod),f[i]=0;IDFT(g,len);
IMP(n-1,f[i]=g[i]);Int(f,n-1);IMP(1<<len,g[i]=0);
}
inline void Exp(int*f,const int&n){
static int b1[M],b2[M],b3[M];const int&m=Getlen(n);b1[0]=1;
for(int len=1;len<=m;++len){
IMP(1<<len-1,b3[i]=b2[i]=b1[i]);Ln(b2,1<<len);IMP(1<<len,b2[i]=Del(f[i],b2[i]));++b2[0];
DFT(b2,len);DFT(b3,len);IMP(1<<len,b2[i]=1ll*b2[i]*b3[i]%mod);IDFT(b2,len);
IMP(1<<len-1,b1[1<<len-1|i]=b2[1<<len-1|i]);
}
IMP(n,f[i]=b1[i]);IMP(1<<m,b1[i]=b2[i]=b3[i]=0);
}
namespace sub1{
std::map<long long,bool>hash;
inline void main(){
int cnt(n);if(y==1)return printf("1"),void();
for(int u,v,i=1;i<n;++i)scanf("%d%d",&u,&v),hash[1ll*u*(n+1)+v]=hash[1ll*v*(n+1)+u]=1;
for(int u,v,i=1;i<n;++i)scanf("%d%d",&u,&v),cnt-=hash[1ll*u*(n+1)+v];
printf("%d",pow(y,cnt));
}
}
namespace sub2{
int k,ege,f[M],g[M],h[M];
struct Edge{
int v,nx;
}e[M<<1];
inline void Add(const int&u,const int&v){
e[++ege]=(Edge){v,h[u]};h[u]=ege;
e[++ege]=(Edge){u,h[v]};h[v]=ege;
}
inline void DFS(const int&u,const int&fa){
f[u]=g[u]=1;
for(int v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^fa){
DFS(v,u);g[u]=(1ll*f[u]*g[v]+(f[v]+1ll*k*g[v])%mod*g[u])%mod;f[u]=(f[v]+1ll*k*g[v])%mod*f[u]%mod;
}
}
inline void main(){
if(y==1)return printf("%d",pow(n,n-2)),void();
for(int u,v,i=1;i<n;++i)scanf("%d%d",&u,&v),Add(u,v);k=1ll*n*y%mod*pow(mod+1-y)%mod;
DFS(1,0);printf("%d",1ll*pow(mod+1-y,n)*pow(n,mod-3)%mod*k%mod*g[1]%mod);
}
}
namespace sub3{
int k,F[M];
inline void main(){
if(y==1)return printf("%d",pow(n,2*n-4)),void();k=1ll*n*n%mod*y%mod*pow(mod+1-y)%mod;
init(n+1<<1);for(int i=1;i<=n;++i)F[i]=1ll*k*pow(i,i)%mod*ifac[i]%mod;Exp(F,n+1);
printf("%d",1ll*pow(mod+1-y,n)*pow(n,mod-5)%mod*fac[n]%mod*F[n]%mod);
}
}
inline int read(){
int n(0);char s;while(!isdigit(s=getchar()));while(n=n*10+(s&15),isdigit(s=getchar()));return n;
}
signed main(){
n=read();y=read();int opt=read();if(opt==0)sub1::main();if(opt==1)sub2::main();if(opt==2)sub3::main();
}