虚树 学习笔记
虚树 学习笔记
引入
我们在解决树上问题时,往往都是对整棵树进行处理,或者每次询问都对一个点、点对进行处理,这类题型一般都可以通过 dp、树剖解决;然而,有一类问题要求我们每次对树上一些关键点进行处理。这类问题的特点就是询问次数多,而询问的点的总数不多。可如果我们每次都把整棵树都 dfs 一遍,时间复杂度就是
虚树
我们把关键点和用于体现关键点之间关系的辅助点连接起来,就形成了虚树。这些辅助点一般都是 LCA。由于至少需要一个关键点才会出现一个辅助点(比如某个点 是 “两个关键点的 LCA” 与 “另一个关键点” 的 LCA),所以最后建出来所有的虚树并遍历的总代价是
建树过程
我们肯定希望辅助点越少越好,但同时还得保证信息正确,所以我们考虑按照 dfs 序建树,因为 dfs 序越相近,两个点在树上的关系越近。我们先把关键点按 dfn 排序,然后用一个栈来维护一条虚树上的链,每次都询问栈顶是否为新点和栈顶的 LCA,如果不是,说明要开一条新的链,就弹栈并加边。最后一定要把栈内剩余元素加边。
参考代码:
void build(){ sort(p+1, p+K+1, cmp); top = 0; stk[++top] = 1; G2.head[1] = 0;//注意不能全部清空,在加边的过程中动态清空即可。 for(int i = 1; i<=K; ++i){ if(p[i] == 1) continue; int lca = th.LCA(stk[top], p[i]); if(lca != stk[top]){ while(dfn[lca] < dfn[stk[top-1]]){ G2.add(stk[top-1], stk[top]); --top; } if(dfn[lca] > dfn[stk[top-1]]){ G2.head[lca] = 0; G2.add(lca, stk[top]), stk[top] = lca; } else{ G2.add(lca, stk[top]); --top; } } G2.head[p[i]] = 0; stk[++top] = p[i]; } for(int i = 1; i<top; ++i){ G2.add(stk[i], stk[i+1]); } }
例题
消耗战
首先 dp 式子很明显,我们分类讨论。如果子节点
代码:
点击查看代码
#include<bits/stdc++.h> #define ll long long using namespace std; const int N = 2.5e5+10; inline int read(){ int x = 0; char ch = getchar(); while(ch<'0' || ch>'9') ch = getchar(); while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar(); return x; } struct node{ int nxt, to, w; }; struct Graph{ int head[N], tot; int num; node edge[N<<1]; void add(int u, int v, int w){ edge[++tot].nxt = head[u]; edge[tot].to = v; edge[tot].w = w; head[u] = tot; } }G1, G2;//原树,虚树。 int dfn[N]; struct HPD{//重链剖分,heavy path decomposition int siz[N], totd, top[N], son[N], dep[N], fa[N]; void dfs1(int u, int fath){ dep[u] = dep[fath]+1; siz[u] = 1; fa[u] = fath; for(int i = G1.head[u]; i; i = G1.edge[i].nxt){ int v = G1.edge[i].to; if(v == fath) continue; dfs1(v, u); siz[u]+=siz[v]; if(siz[son[u]]<siz[v]) son[u] = v; } } void dfs2(int u, int Top){ top[u] = Top; dfn[u] = ++totd; if(!son[u]) return; dfs2(son[u], Top); for(int i = G1.head[u]; i; i = G1.edge[i].nxt){ int v = G1.edge[i].to; if(!dfn[v]) dfs2(v, v); } } int LCA(int x, int y){ while(top[x] != top[y]){ if(dep[top[x]] < dep[top[y]]) swap(x, y); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x, y); return x; } }th; int n; int m, K; int dst[N], p[N]; void dfsG1(int u, int fath){ for(int i = G1.head[u]; i; i = G1.edge[i].nxt){ int v = G1.edge[i].to; if(v == fath) continue; dst[v] = min(dst[u], G1.edge[i].w); dfsG1(v, u); } } bool cmp(int a, int b){ return dfn[a] < dfn[b]; } int stk[N], tp; bool is_tar[N]; void build(){ sort(p+1, p+K+1, cmp); tp = 0; stk[++tp] = 1, G2.head[1] = 0; for(int i = 1, l; i<=K; ++i){ if(p[i] == 1) continue; l = th.LCA(p[i], stk[tp]); if(l != stk[tp]){ while(dfn[l] < dfn[stk[tp-1]]){ G2.add(stk[tp-1], stk[tp], dst[stk[tp]]); --tp; } if(dfn[l] > dfn[stk[tp-1]]){ G2.head[l] = 0; G2.add(l, stk[tp], dst[stk[tp]]), stk[tp] = l; } else{ G2.add(l, stk[tp], dst[stk[tp]]); --tp; } } G2.head[p[i]] = 0; stk[++tp] = p[i]; } for(int i = 1; i<tp; ++i){ G2.add(stk[i], stk[i+1], dst[stk[i+1]]); } } ll f[N]; void dfs_ans(int u, int fath){ f[u] = 0; for(int i = G2.head[u]; i; i = G2.edge[i].nxt){ int v = G2.edge[i].to; if(v == fath) continue; dfs_ans(v, u); if(is_tar[v]){ f[u]+=G2.edge[i].w; } else{ f[u]+= min(f[v], 1ll*G2.edge[i].w); } } } int main(){ n = read(); dst[1] = 0x3f3f3f3f; for(int i = 1; i<n; ++i){ int u = read(), v = read(), w = read(); G1.add(u, v, w); G1.add(v, u, w); } th.dfs1(1, 0); th.dfs2(1, 1); dfsG1(1, 0); m = read(); while(m--){ K = read(); G2.tot = 0;//一定注意要清空! for(int i = 1; i<=K; ++i){ p[i] = read(); is_tar[p[i]] = 1; } build(); dfs_ans(1, 0); printf("%lld\n", f[1]); for(int i = 1; i<=K; ++i){ is_tar[p[i]] = 0; } } return 0; }
大工程
也是考虑建好虚树后怎么做。最大值和最小值都可以通过拼接求得,每次找出最大/最小和次大/次小值拼接即可。至于路径权值和,我们考虑每条边的贡献,发现就等于这条边所连接的两棵子树中关键点数量的乘积。至于建好虚树后新边的边权,因为是单位边权,所以直接通过两点的深度做差即可求得。
代码:
点击查看代码
#include<bits/stdc++.h> using namespace std; const int N = 1e6+100; const int INF = 0x3f3f3f3f; inline int read(){ int x = 0; char ch = getchar(); while(ch<'0' || ch>'9') ch = getchar(); while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar(); return x; } struct node{ int nxt, to; }; struct Graph{ int tot, head[N]; node edge[N<<1]; void add(int u, int v){ edge[++tot].nxt = head[u]; edge[tot].to = v; head[u] = tot; } }G1, G2; int dep[N], dfn[N], totd; struct HPD{ private: int fa[N], top[N], son[N], siz[N]; public: void dfs1(int u, int fath){ dep[u] = dep[fath]+1; fa[u] = fath; siz[u] = 1; for(int i = G1.head[u]; i; i = G1.edge[i].nxt){ int v = G1.edge[i].to; if(v == fath) continue; dfs1(v, u); siz[u]+=siz[v]; if(siz[son[u]] < siz[v]) son[u] = v; } } void dfs2(int u, int Top){ dfn[u] = ++totd; top[u] = Top; if(!son[u]) return; dfs2(son[u], Top); for(int i = G1.head[u]; i; i = G1.edge[i].nxt){ int v = G1.edge[i].to; if(!dfn[v]) dfs2(v, v); } } inline int LCA(int x, int y){ while(top[x] != top[y]){ if(dep[top[x]] < dep[top[y]]) swap(x, y); x = fa[top[x]]; } if(dep[x] > dep[y]) swap(x, y); return x; } }th; int K; int stk[N], top; int p[N]; bool is_tar[N]; bool cmp(int x, int y){ return dfn[x] < dfn[y]; } void build(){ sort(p+1, p+K+1, cmp); top = 0; stk[++top] = 1; G2.head[1] = 0; for(int i = 1; i<=K; ++i){ if(p[i] == 1) continue; int lca = th.LCA(stk[top], p[i]); if(lca != stk[top]){ while(dfn[lca] < dfn[stk[top-1]]){ G2.add(stk[top-1], stk[top]); --top; } if(dfn[lca] > dfn[stk[top-1]]){ G2.head[lca] = 0; G2.add(lca, stk[top]), stk[top] = lca; } else{ G2.add(lca, stk[top]); --top; } } G2.head[p[i]] = 0; stk[++top] = p[i]; } for(int i = 1; i<top; ++i){ G2.add(stk[i], stk[i+1]); } } int fmn[N], fmx[N]; long long fsum[N]; int mn, mx; long long sum; void dfs_ans(int u, int fath){ int firmn = INF, secmn = INF; fmn[u] = INF, fmx[u] = 0; int firmx = 0, secmx = 0; fsum[u] = 0; if(is_tar[u]){ fsum[u] = 1; } for(int i = G2.head[u]; i; i = G2.edge[i].nxt){ int v = G2.edge[i].to; if(v == fath) continue; dfs_ans(v, u); if(is_tar[v]){ fmn[u] = min(fmn[u], dep[v]-dep[u]); if(dep[v]-dep[u] < firmn){ secmn = firmn; firmn = dep[v]-dep[u]; } else if(dep[v]-dep[u]<secmn){ secmn = dep[v]-dep[u]; } } else{ fmn[u] = min(fmn[v]+dep[v]-dep[u], fmn[u]); if(fmn[v]+dep[v]-dep[u] < firmn){ secmn = firmn; firmn = fmn[v]+dep[v]-dep[u]; } else if(fmn[v]+dep[v]-dep[u]<secmn){ secmn = fmn[v]+dep[v]-dep[u]; } } fmx[u] = max(fmx[u], fmx[v]+dep[v]-dep[u]); if(fmx[v]+dep[v]-dep[u] > firmx){ secmx = firmx; firmx = fmx[v]+dep[v]-dep[u]; } else if(fmx[v]+dep[v]-dep[u] > secmx){ secmx = fmx[v]+dep[v]-dep[u]; } fsum[u]+=fsum[v]; sum+=(fsum[v]*(K-fsum[v])*(dep[v]-dep[u])); } if(is_tar[u]){ mn = min(mn, fmn[u]); } else{ mn = min(mn, secmn+firmn); } if(secmx){ mx = max(mx, firmx+secmx); } else if(is_tar[u]){ mx = max(fmx[u], mx); } } int n, Q; int main(){ n = read(); for(int i = 1; i<n; ++i){ int u = read(), v = read(); G1.add(u, v); G1.add(v, u); } th.dfs1(1, 0); th.dfs2(1, 1); Q = read(); while(Q--){ K = read(); G2.tot = 0; for(int i = 1; i<=K; ++i){ p[i] = read(); is_tar[p[i]] = 1; } build(); sum = mx = 0, mn = INF; dfs_ans(1, 0); printf("%lld %d %d\n", sum, mn, mx); for(int i = 1; i<=K; ++i){ is_tar[p[i]] = 0; } } return 0; }