WC2019 数树
数树
本题包含三个问题:
-
问题 0:已知两棵 \(n\) 个节点的树的形态(两棵树的节点标号均为 \(1\) 至 \(n\)),其中第一棵树是红树,第二棵树是蓝树。要给予每个节点一个 \([1, y]\) 中的整数,使得对于任意两个节点 \(p, q\),如果存在一条路径 \((a_1 = p, a_2, \cdots , a_m = q)\) 同时属于这两棵树,则 \(p, q\) 必须被给予相同的数。求给予数的方案数。
- 存在一条路径同时属于这两棵树的定义见「题目背景」。
-
问题 1:已知蓝树,对于红树的所有 \(n^{n−2}\) 种选择方案,求问题 0 的答案之和。
-
问题 2:对于蓝树的所有 \(n^{n−2}\) 种选择方案,求问题 1 的答案之和。
提示:\(n\) 个节点的树一共有 \(n^{n−2}\) 种。
在不同的测试点中,你将可能需要回答不同的问题。我们将用 \(\text{op}\) 来指代你需要回答的问题编号(对应上述 0、 1、 2)。
由于答案可能很大,因此你只需要输出答案对 \(998, 244, 353\) 取模的结果即可。
所有测试点均满足 \(3 \le n \le 10^5, 1 \le y \lt 998244353, \text{op} \in \{0, 1, 2\}\)。
问题 0
算一算两树的公共边数即可。
连接的路径相同说明路径上的边都是共有的边。贡献折算下来一条公共边会导致自由变量的个数减一。
set统计,时间复杂度 \(O(n \log n)\)。
namespace Subtask0{
set<pair<int,int> > edge;
void main(int n,int y){
for(int i=1;i<n;++i){
int u=read<int>(),v=read<int>();
if(u>v) swap(u,v);
edge.emplace(u,v);
}
int cnt=n;
for(int i=1;i<n;++i){
int u=read<int>(),v=read<int>();
if(u>v) swap(u,v);
if(edge.count(make_pair(u,v))) --cnt;
}
printf("%d\n",fpow(y,cnt));
}
}
问题 1
https://blog.csdn.net/qq_39972971/article/details/88935680
考虑枚举蓝树上的一个边集 \(S\),强制红树上同样存在这些边,计算将剩余 \(n-|S|\) 个连通块连成一棵树的方案数,更新答案。
由问题 0可知,我们只关心公共边的数量。
上述算法中枚举大小为 \(S\) 的边集计算的时候,一个恰好包含 \(T\) 条蓝边的红树会被计算 \(\binom{T}{S}\) 次,并且由于最终的生成树上的蓝边数尚不确定,我们也无法得知需要乘上 \(y\) 的多少次方。
注意到
若看作每选取一条蓝边产生 \(\times (y^{-1}-1)\) 的贡献,我们就可以直接应用上述算法,只需将最终答案乘以 \(y^n\) 即可。
这个构造简直绝了。我做容斥题第一次见到逆向运用二项式定理的。
总方案数除以被限制的个数也是一种常用的套路。
由Prufer序列,将 \(m\) 个大小分别为 \(a_1,a_2,\dots,a_m~(\sum a_i=n)\) 的连通块连成一棵树的方案数为 \(n^{m-2}\prod a_i\)。\(n^{m-2}\) 部分的贡献可以看做每选取一条蓝边产生 \(\times n^{-1}\) 的贡献,最后将答案乘以 \(n^{n-2}\)。\(\prod a_i\) 部分的贡献对应的组合意义为在每一个连通块内选择恰好一个代表点的方案数。
我们给每个连通块分配一个标号,然后把这个标号拿去做Prufer序列。
假设生成了一个长度为 \(m-2\) 的序列 \(P\),那么方案数为 \(\prod_{i=1}^m a_i^{t_i+1}\)。其中 \(t_i\) 表示 \(i\) 号连通块在序列 \(P\) 中的出现次数。\(i\) 总共连了 \(deg_i=t_i+1\) 条边,而一次在 \(i,j\) 之间的连边的贡献是 \(a_i\times a_j\),所以方案数是那个连乘式。
因此总方案数为
\[tot=\sum_{P}\prod_{i=1}^m a_i^{t_i+1}\\ =\prod_{i=1}^m a_i\sum_{P}\prod_{i=1}^m a_i^{t_i} \]考虑右边那个求和的意义。相当于给你 \(m-2\) 个空,每个空里面可以填 \(a_1\sim a_m\),一种填法的贡献是所有空里面填的数的乘积,然后我们要对所有填法的贡献求和。
那么总方案数可以认为是
\[tot=\prod_{i=1}^m a_i (\sum_{i=1}^ma_i)^{m-2}\\ =n^{m-2} \prod_{i=1}^m a_i \]
又是一波构造。给公式赋予比较容易处理的组合意义,第一次见到运用。
生成树的数量是 \(n^{n-2}\),而又有 \(n-m\) 条蓝边,所以给每条蓝边分配一个 \(n^{-1}\) 的贡献,这样最后在外面乘上一个 \(n^{n-2}\) 就可以凑出 \(n^{m-2}\)。
至此,我们可以设计一个简单树形DP来加速上述枚举算法。
记 \(dp(i,0/1)\) 表示 \(i\) 所在连通块有没有选出代表点的情况下,其子树内上述算法所有情况的贡献总和,可以 \(O(1)\) 简单转移。
\[w=(y^{-1}-1)n^{-1}\\ dp(u,0)=dp(u,0)\times dp(v,1)+dp(u,0)\times dp(v,0)\times w\\ dp(u,1)=dp(u,0)\times dp(v,1)\times w+dp(u,1)\times dp(v,0)\times w+dp(u,1)\times dp(v,1) \]
时间复杂度 \(O(n)\)。
namespace Subtask1{
CO int N=100000+10;
vector<int> to[N];
int weight,dp[N][2];
void dfs(int u,int fa){
dp[u][0]=dp[u][1]=1;
for(int v:to[u])if(v!=fa){
dfs(v,u);
int res[2];
res[0]=mul(dp[u][0],dp[v][1]);
res[0]=add(res[0],mul(dp[u][0],mul(dp[v][0],weight)));
res[1]=mul(dp[u][0],mul(dp[v][1],weight));
res[1]=add(res[1],mul(dp[u][1],mul(dp[v][0],weight)));
res[1]=add(res[1],mul(dp[u][1],dp[v][1]));
dp[u][0]=res[0],dp[u][1]=res[1];
}
}
void main(int n,int y){
for(int i=1;i<n;++i){
int u=read<int>(),v=read<int>();
to[u].push_back(v),to[v].push_back(u);
}
weight=mul(add(fpow(y,mod-2),mod-1),fpow(n,mod-2));
dfs(1,0);
int ans=mul(dp[1][1],mul(fpow(y,n),fpow(n,n-2)));
printf("%d\n",ans);
}
}
问题 2
由问题 1的算法,我们可以设计一个简单DP来完成该问题。
记 \(f_i\) 表示共有 \(i\) 个点的情况下,问题 1的枚举算法所有情况的贡献总和。
枚举 \(1\) 号点所在连通块大小 \(j\) ,则有:
最终答案即为 \(ans=f_n\times y^n \times n^{2n-4}\)。
红蓝树此时等价。我们还是去统计每个连通块的贡献,但是由于此时对于连通块的限制只有点数,所以我们的连通块其实就是任意的生成树。
此时我们需要在两个图中把所有连通块连成一棵树,所以边的代价变化了 \(\times n^{-2}\)。代表点也需要在两个图中都选一个。
直接使用分治NTT优化DP,时间复杂度 \(O(n \log^2 n)\)。
考场上集训队选手的最佳做法。
记 \(F(x)\) 为 \(f\) 的指数型生成函数,
上述转移方程即
时间复杂度 \(O(n\log n)\)。
CO int N=262144+10;
int fac[N],inv[N],ifac[N];
void NTT(poly&a,int dir){
static int rev[N],omg[N];
int lim=a.size(),len=log2(lim);
for(int i=0;i<lim;++i) rev[i]=rev[i>>1]>>1|(i&1)<<(len-1);
for(int i=0;i<lim;++i)if(i<rev[i]) swap(a[i],a[rev[i]]);
omg[0]=1,omg[1]=fpow(dir==1?3:332748118,(mod-1)/lim);
for(int i=2;i<lim;++i) omg[i]=mul(omg[i-1],omg[1]);
for(int i=1;i<lim;i<<=1)
for(int j=0;j<lim;j+=i<<1)
for(int k=0;k<i;++k){
int t=mul(omg[lim/(i<<1)*k],a[j+i+k]);
a[j+i+k]=add(a[j+k],mod-t),a[j+k]=add(a[j+k],t);
}
if(dir==-1){
for(int i=0;i<lim;++i) a[i]=mul(a[i],inv[lim]);
}
}
poly inver(poly a){
int n=a.size();
poly b(1,fpow(a[0],mod-2));
if(n==1) return b;
int lim=2;
for(;lim<n;lim<<=1){
poly a1(a.begin(),a.begin()+lim);
a1.resize(lim<<1),NTT(a1,1);
b.resize(lim<<1),NTT(b,1);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a1[i],b[i])),b[i]);
NTT(b,-1),b.resize(lim);
}
a.resize(lim<<1),NTT(a,1);
b.resize(lim<<1),NTT(b,1);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a[i],b[i])),b[i]);
NTT(b,-1),b.resize(n);
return b;
}
poly differ(CO poly&a){
poly b(a.size()-1);
for(int i=0;i<(int)b.size();++i) b[i]=mul(a[i+1],i+1);
return b;
}
poly inter(CO poly&a){
poly b(a.size()+1);
for(int i=1;i<(int)b.size();++i) b[i]=mul(a[i-1],inv[i]);
return b;
}
poly log(poly a){
int n=a.size();
poly b=inver(a);
a=differ(a);
int lim=1<<int(ceil(log2(2*n-2)));
a.resize(lim),NTT(a,1);
b.resize(lim),NTT(b,1);
for(int i=0;i<lim;++i) a[i]=mul(a[i],b[i]);
NTT(a,-1),a.resize(n);
a=inter(a),a.resize(n);
return a;
}
poly exp(poly a){
int n=a.size();
poly b(1,1);
if(n==1) return b;
int lim=2;
for(;lim<n;lim<<=1){
poly a1(a.begin(),a.begin()+lim);
a1.resize(lim<<1),NTT(a1,1);
b.resize(lim);poly b1=log(b);
b1.resize(lim<<1),NTT(b1,1);
b.resize(lim<<1),NTT(b,1);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(1,add(a1[i],mod-b1[i])),b[i]);
NTT(b,-1),b.resize(lim);
}
a.resize(lim<<1),NTT(a,1);
b.resize(lim);poly b1=log(b);
b1.resize(lim<<1),NTT(b1,1);
b.resize(lim<<1),NTT(b,1);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(1,add(a[i],mod-b1[i])),b[i]);
NTT(b,-1),b.resize(n);
return b;
}
namespace Subtask2{
void main(int n,int y){
fac[0]=1;
for(int i=1;i<N;++i) fac[i]=mul(fac[i-1],i);
inv[0]=inv[1]=1;
for(int i=2;i<N;++i) inv[i]=mul(mod-mod/i,inv[mod%i]);
ifac[0]=1;
for(int i=1;i<N;++i) ifac[i]=mul(ifac[i-1],inv[i]);
int weight=mul(add(fpow(y,mod-2),mod-1),fpow(n,2*mod-4));
poly coef(n+1);
for(int i=0;i<=n;++i) coef[i]=mul(fpow(i+1,i+1),mul(fpow(weight,i),ifac[i]));
poly lnres=inter(coef);lnres.resize(n+1);
poly res=exp(lnres);
int ans=mul(res[n],fac[n]);
ans=mul(ans,mul(fpow(y,n),fpow(n,2*n-4)));
printf("%d\n",ans);
}
}