【学习笔记】虚树
虚树学习笔记
前言:前两天北航校赛有一题考了这个,看到题解觉得挺神奇的,之后又在南昌网络赛的题解里看到这个词,赶紧先来把板子学了
update1: 把南昌网络赛的F题补上了,写的时候尝试一下了把虚树写进struct里
先了解几个基本的问题
Q:虚树是做什么的?
A:虚树就是把一棵树上的对问题有影响的点扒出来,建一棵新的树
Q:为什么会需要这个东西?
A:如果题目有多组询问,每次询问的点不同,树上其他的点不会对答案产生影响,每次询问对所有的点跑DFS或者DP,是很浪费时间的,尤其是询问的数量特别多的情况
Q:对每次询问建一棵新的树,这样不会很浪费时间吗?
A:对\(m\)个点,建立虚树的复杂度可以优化到\(O(m\log n)\) (甚至可以优化到\(O(m)\))。对所有询问,总的复杂度就是\(O(\sum{m} \log{n})\), 本篇文章的核心也在于虚树的构建。
Q:如果\(\sum{m}=q*n\)呢?
A:再见
虚树的构造
基本思路:
- 用栈维护一个链
- 按照dfs序,从小到大尝试向虚树中添加点
- 如果当前的点可以链接到 当前栈所维护的链,或者栈中的点少于1(没有链),直接入栈
- 如果当前的点不能链接到栈中的链,即\(lca(u,s.top)\neq s.top\)
取\(v=s.top,g=lca(u,v)\),说明\(u,v\)分别在\(g\)的两个子树上,且\(v\)所在的子树已经全部扒完了(从dfs序的角度考虑
这种情况下,把\(v\)所在的\(g\)子树退栈,边退栈边在虚树中建边,直到\(pre(s.top) <= pre(g)\)
如果\(s.top \neq g\), \(g\)入栈,之前退掉的最后一个点连到\(g\)上 - 所有的点填入之后,栈中保存的是一条链,不在栈中的也通过\(lca\)和这条链相连接
整理一下:
先将询问的点按照\(dfs\)序从小到大排序
对于当前点\(u\):\
- 如果当前栈大小小于1,直接入栈
- 如果\(lca(s.top.u)=s.top\), 直接入栈
- 如果\(lca(s.top,u)\neq s.top\),将\(s.top\)向上直至\(lca\)的链退栈并连边(保留\(lca\)),\(u\)入栈
- 如果\(u\)是最后一个点,清空栈,同时将栈中的点连边
代码如下
sort(p+1, p+1+Vn, cmp);
stak[++top] = 1;
for(int i = 1; i <= Vn; i++){
if(top <= 1){
stak[++top] = p[i];
}
else{
int g = lca(stak[top],p[i]);
if(g == stak[top]){
stak[++top] = p[i];
}
else{
while(top > 1 && pre[stak[top-1]] >= pre[g]){
VG[stak[top-1]].push_back(stak[top]);
top--;
}
if(g != stak[top]){
VG[g].push_back(stak[top]);
stak[top] = g;
}
stak[++top] = p[i];
}
}
}
while(top > 1){
VG[stak[top-1]].push_back(stak[top]);
top--;
}
例题1 [SDOI2011]消耗战
题意:
给定一个树\((2\leq n \leq 250000)\),\(m\)次询问,每次给定\(k_i\)个点,回答最少需要切断的边的权值之和,让\(1\)与给定的\(k_i\)个点不连通
\(\sum{k_i} \leq 500000\)
解法:
- 先考虑单次询问,可以预处理每个点的 令点\(i\)与\(1\)不连通的最小费用\(cost_i\), 然后\(dfs\)即可
- 再考虑多次询问,由于\(\sum{k_i}\)是固定的,显然可以构建虚树解决
参考代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 3e5;
//
struct Edge{
int u,v;
int d;
int nxt;
Edge(int u,int v,int d,int nxt):
u(u),v(v),d(d),nxt(nxt){};
Edge(){};
};
Edge e[maxn*2];
int edge_cnt, head[maxn];
int n, m;
//
//lca
int fa[maxn][20];
int pre[maxn];
int dfs_clock = 0;
int depth[maxn];
int lg[maxn];
//
int cost[maxn];
void dfs(int u){
pre[u] = ++dfs_clock;
for(int i = head[u]; i; i = e[i].nxt){
if(pre[e[i].v])continue;
// 倍增LCA
fa[e[i].v][0] = u;
depth[e[i].v] = depth[u] + 1;
for(int j = 1; j < 20; j++){
fa[e[i].v][j] = fa[ fa[e[i].v][j-1] ][j-1];
if(fa[e[i].v][j] == 0)break;
}
//
// DP处理 切断root到当前子树需要最小的cost
if(u == 1)cost[e[i].v] = e[i].d;
else cost[e[i].v] = min(cost[u], e[i].d);
dfs(e[i].v);
}
return;
}
int lca(int u,int v){
if(depth[u] < depth[v])swap(u,v);
while(depth[u] > depth[v])
for(int j = lg[depth[u]-depth[v]]; j >= 0; j--){
if(depth[fa[u][j]] >= depth[v])u = fa[u][j];
}
while(u != v){
for(int j = lg[depth[u]]; j >= 0; j--){
while(fa[u][j] == fa[v][j] && j > 0)j--;
u = fa[u][j]; v = fa[v][j];
}
}
return u;
}
bool k[maxn];
int p[maxn];
int stak[maxn];
int top = 0;
// 虚树上dfs
vector<int> VG[maxn];
ll solve(int u){
ll cost = 0;
for(int i : VG[u]){
cost += solve(i);
}
VG[u].clear();
if(u == 1)return cost;
if(k[u])return cost[u];
else return min(cost[u],cost);
}
//按照dfs序对结点进行排序
bool cmp(int u,int v){
return pre[u] < pre[v];
}
int main(){
// head
#ifdef FWL
freopen("data.in","r",stdin);
#endif // FWL
ios::sync_with_stdio(false);
cin.tie(0);
//
//init
for(int i = 2; i < maxn; i++){
if(i&(i-1))lg[i] = lg[i-1];
else lg[i] = lg[i-1] + 1;
}
edge_cnt = 1;
//
//read & build tree
cin >> n;
for(int i = 1; i < n; i++){
int u,v; ll d;
cin >> u >> v >> d;
e[++edge_cnt] = Edge(u,v,d,head[u]); head[u] = edge_cnt;
e[++edge_cnt] = Edge(v,u,d,head[v]); head[v] = edge_cnt;
}
depth[1] = 0;
dfs(1);
//
// solve
cin >> m;
cost[1] = 2e9; //第一个结点不可能切的掉的
for(int i = 1; i <= m; i++){
int Vn;
cin >> Vn;
for(int i = 1; i <= Vn; i++){
cin >> p[i];
}
top = 0; // init stack
for(int i = 1; i <= Vn; i++) k[ p[i] ] = true;
sort(p+1, p+1+Vn, cmp);
stak[++top] = 1; //根节点先入栈
for(int i = 1; i <= Vn; i++){
if(top <= 1){
stak[++top] = p[i];
}
else{
int g = lca(stak[top],p[i]);
if(g == stak[top]){
stak[++top] = p[i];
}
else{
while(top > 1 && pre[stak[top-1]] >= pre[g]){
VG[stak[top-1]].push_back(stak[top]);
top--;
}
if(g != stak[top]){
VG[g].push_back(stak[top]);
stak[top] = g;
}
stak[++top] = p[i];
}
}
}
while(top > 1){
VG[stak[top-1]].push_back(stak[top]);
top--;
}
cout << solve(1) << '\n';
//clear
for(int i = 1; i <= Vn; i++) k[ p[i] ] = false;
//
}
return 1; //Don't Cheat
}
例题2 2019南昌网络赛 F.Information Transmitting
题意:
\(T\)组数据\((t\leq5)\),给定一棵树\(n\leq 10^5\),边权为概率,代表两点之间信息传递成功的概率,\(Q\)组询问\(Q\leq 10^5\),每次询问给\(M_1\)个点发出信息,\(M_2\)个点接受信息,输出这\(M_2\)个点收到信息的概率\((\sum{M_1}+\sum{M_2}\leq 500000)\)
解法:
先考虑单次询问:
dfs处理:结点\(i\)只能从以\(i\)为根的子树获取信息情况下的结果
第二次dfs处理:结点\(i\)从父亲得到信息的概率
两次dfs的结果加起来,就是最终的结果
需要注意的地方:概率加减不是简单的\(x+y\)
inline double add(double x,double y){return x+y-x*y;}
inline double sub(double x,double y){return (x-y)/(1.0-y);}
多组询问:看到题目中的限制条件就很显然是虚树了
参考代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 5;
int n,q;
struct Edge{
int u,v;
double p;
int nxt;
Edge(int u,int v,int p,int nxt):
u(u),v(v),p(double(p)/100.0),nxt(nxt){};
Edge(int u,int v,double p,int nxt):
u(u),v(v),p(p),nxt(nxt){};
Edge(){};
};
int head[maxn],edge_cnt;
Edge e[maxn*2];
int fa[maxn][20];
double possible[maxn][20];
int depth[maxn];
int pre[maxn], dfs_clock;
void dfs(int u){
pre[u] = ++dfs_clock;
for(int i = head[u]; i; i = e[i].nxt){
if(pre[e[i].v])continue;
depth[e[i].v] = depth[u]+1;
fa[e[i].v][0] = u;
possible[e[i].v][0] = e[i].p;
for(int j = 1; j < 20; j++){
fa[e[i].v][j] = fa[ fa[e[i].v][j-1]][j-1];
possible[e[i].v][j] = possible[e[i].v][j-1] * possible[ fa[e[i].v][j-1] ][j-1];
if(fa[e[i].v][j] == 0)break;
}
dfs(e[i].v);
}
}
bool cmp(int u,int v){
return pre[u] < pre[v];
}
int lg[maxn];
int lca(int u,int v){
int k = 0;
if(depth[u] < depth[v]){swap(u,v);k = 1;}
while(depth[u] > depth[v]){
int j = lg[depth[u]-depth[v]];
u = fa[u][j];
}
while(u != v){
for(int j = lg[depth[u]]; j >= 0; j--){
while(fa[u][j] == fa[v][j] && j > 0)j--;
u = fa[u][j]; v = fa[v][j];
}
}
return u;
}
double dist(int u,int v){
if(depth[u] < depth[v])swap(u,v);
double ans = 1;
while(depth[u] > depth[v]){
int j = lg[depth[u] - depth[v]];
ans *= possible[u][j];
u = fa[u][j];
}
assert(u == v);
return ans;
}
inline double add(double x,double y){return x+y-x*y;}
inline double sub(double x,double y){return (x-y)/(1.0-y);}
struct VTree{
int n;
int head[maxn];
Edge e[maxn*2];
double a[maxn];
double b[maxn];
int edge_cnt;
void init(int n){
this->n = n;
memset(head, 0, sizeof(head));
edge_cnt = 1;
}
void add_edge(int u,int v,double p){
e[++edge_cnt] = Edge(u,v,p,head[u]);head[u] = edge_cnt;
}
void add_edge(int u,int v){
double p = dist(u,v);
e[++edge_cnt] = Edge(u,v,p,head[u]);head[u] = edge_cnt;
}
void clear(){
clear(1);
edge_cnt = 1;
}
void clear(int u){
for(int & i = head[u]; i; i = e[i].nxt){
clear(e[i].v);
}
a[u] = b[u] = 0;
}
void dfs1(int u){
for(int i = head[u]; i; i = e[i].nxt){
dfs1(e[i].v);
a[u] = add(a[u], a[e[i].v]*e[i].p);
}
}
void dfs2(int u){
for(int i = head[u]; i; i = e[i].nxt){
if(u == 1)b[e[i].v] = sub(a[u],a[e[i].v]*e[i].p)*e[i].p;
else b[e[i].v] = add(b[u],sub(a[u],a[e[i].v]*e[i].p))*e[i].p;
dfs2(e[i].v);
}
}
double ask(int u){
return add(a[u],b[u]);
}
} VG;
int _stack[maxn],top;
int p[maxn*2],to[maxn];
void fuck(int u = 0){
#ifdef FWL
cout << "DEBUG " << u << endl;
#endif // FWL
}
int mian(){ //Don't copy
#ifdef FWL
freopen("data.in","r",stdin);
freopen("data.out","w",stdout);
#endif // FWL
ios::sync_with_stdio(false);
cin.tie(0);
for(int i = 2; i < maxn; i++){
if(i & (i-1))lg[i] = lg[i-1];
else lg[i] = lg[i-1] + 1;
}
int t;
cin >> t;
for(int kase = 1; kase <= t; kase++){
edge_cnt = 1;dfs_clock = 0;
cin >> n >> q;
memset(head, 0, sizeof(head));
memset(pre, 0, sizeof(pre));
for(int i = 1; i < n; i++){
int u,v,p;
cin >> u >> v >> p;
e[++edge_cnt] = Edge(u,v,p,head[u]); head[u] = edge_cnt;
e[++edge_cnt] = Edge(v,u,p,head[v]); head[v] = edge_cnt;
}
dfs(1);
top = 0;
_stack[++top] = 1;
for(int cnt = 1; cnt <= q; cnt++){
int m1,m2;
cin >> m1 >> m2;
for(int i = 1; i <= m1; i++) cin >> p[i];
for(int i = 1; i <= m1; i++) VG.a[p[i]] = 1.0;
for(int i = 1; i <= m2; i++) {cin >> to[i]; p[m1+i] = to[i];}
sort(p+1,p+m1+m2+1,cmp);
m1 = unique(p+1,p+m1+m2+1)-p-1;
int i = 1;
if(p[i] == 1)i++;
for(; i <= m1; i++){
if(top <= 1){
_stack[++top] = p[i];
}
else {
int g = lca(_stack[top],p[i]);
if(g == _stack[top]){
_stack[++top] = p[i];
}
else{
while(top > 1 && pre[_stack[top-1]] >= pre[g]){
VG.add_edge(_stack[top-1],_stack[top]);
top--;
}
if(_stack[top] != g){
VG.add_edge(g,_stack[top]);
_stack[top] = g;
}
_stack[++top] = p[i];
}
}
}
while(top > 1){
VG.add_edge(_stack[top-1],_stack[top]);
top--;
}
VG.dfs1(1);
VG.dfs2(1);
for(int i = 1; i <= m2; i++){
cout << fixed << setprecision(7) << VG.ask(to[i]) << ' ';
}
cout << '\n';
VG.clear();
}
}
return 0;
}