【UNR 3】配对树(线段树合并)
题意
给定一棵 \(n\) 个结点的树和一个长度为 \(m\) 的结点序列。对于一个大小为偶数的点集 \(S\)(集合元素可重复),定义 \(w(S)\) 为:把 \(S\) 中的点两两匹配,每对匹配的树上距离之和的最小值。现在要对序列中所有长度为偶数的区间 \([l,r]\),求出 \(w(\{a_l,a_{l+1},\cdots,a_r\})\) 的和。
\(n,m\le 10^5\),答案对 \(998244353\) 取模。
分析
首先观察到 \(w(S)\) 可以拆到每条边上算。对于一条边:
- 它两边的子树中都有奇数个 \(S\) 中的点时会被算 \(1\) 次。
- 它两边的子树中都有偶数个 \(S\) 中的点时会被算 \(0\) 次。
因此我们可以钦定一个根,然后对于每个子树,计算有多少个长度为偶数的区间 \([l,r]\) 满足 \(a_l,a_{l+1},\cdots,a_r\) 中有奇数个点在这棵子树中。
我们只讨论 \(l\) 为奇数、\(r\) 为偶数的情况,另一种情况是一样的。设 \(c_{u,i}\) 表示 \(a_1,a_2,\cdots,a_i\) 在 \(u\) 子树内出现的次数。观察到 \([l,r]\) 合法当且仅当 \(c_{u,l-1}\) 与 \(c_{u,r}\) 的奇偶性不同。因此我们只要算出使 \(c_{u,2i}\) 为奇数的 \(i\) 的个数。设它为 \(x\),则 \(u\) 到父亲这条边就要算 \(x\times(\lfloor{n\over2}\rfloor+1-x)\) 次。
我们可以用线段树合并来维护 \(b_{u_i}=c_{u,2i} \bmod 2\)。\(b_u\) 实际上是一个 01 数组 \(b'_u\) 的前缀异或和。因此对于线段树上的一个结点,我们只要记录对应区间上 \(b_u\) 的和以及 \(b'_u\) 中 1 的个数的奇偶性即可。时间复杂度 \(O(n\log n)\)。
实现
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define pb push_back
#define md ((l+r)>>1)
using namespace std;
typedef long long ll;
const int mod=998244353,maxn=1e5+5,maxm=maxn*17;
int n,m,L,N,ans,a[maxn],lc[maxm],rc[maxm],cn[maxm],su[maxm],rt[maxn];
struct edge{int v,w;};
vector<edge>G[maxn];
vector<int>V[maxn];
inline int inc(int x,int y){return x+=y-mod,x+=x>>31&mod;}
inline int mul(int x,int y){return ll(x)*y%mod;}
inline int nnd(){++N,lc[N]=rc[N]=cn[N]=su[N]=0;return N;}
void up(int x,int l,int r){
cn[x]=cn[lc[x]]^cn[rc[x]];
su[x]=su[lc[x]]+(cn[lc[x]]?r-md-su[rc[x]]:su[rc[x]]);
}
int merge(int x,int y,int l,int r){
if(!x||!y)return x+y;
if(l==r)return su[x]=cn[x]^=cn[y],x;
lc[x]=merge(lc[x],lc[y],l,md);
rc[x]=merge(rc[x],rc[y],md+1,r);
return up(x,l,r),x;
}
void modify(int&x,int l,int r,int i){
if(!x)x=nnd();
if(l==r){su[x]=cn[x]^=1;return;}
i<=md?modify(lc[x],l,md,i):modify(rc[x],md+1,r,i);up(x,l,r);
}
void dfs(int u,int _v=0,int w=0){
rt[u]=0;
for(int i:V[u])modify(rt[u],1,L,i);
for(edge e:G[u])if(e.v!=_v)dfs(e.v,u,e.w),rt[u]=merge(rt[u],rt[e.v],1,L);
ans=inc(ans,mul(w,mul(su[rt[u]],L-su[rt[u]])));
}
int main(){
scanf("%d%d",&n,&m);
rep(i,1,n-1){
int u,v,w;scanf("%d%d%d",&u,&v,&w);
G[u].pb({v,w}),G[v].pb({u,w});
}
rep(i,1,m)scanf("%d",&a[i]);
rep(k,1,2){
L=(m+k)>>1,N=0;
rep(i,1,n)V[i].clear();
rep(i,1,m)if((i+k+1)>>1<=L)V[a[i]].pb((i+k+1)>>1);
dfs(1);
}
printf("%d\n",ans);
return 0;
}