虚树详解
我们先从一道经典的例题入手:
[SDOI2011]消耗战
题意:
给出一棵树,每条边有边权。
有m次询问,每次询问给出k个点,问使得这k个点均不与1号点(根节点)相连的最小代价
\(1\leq n\leq 2.5\times 10^5,1\leq m\leq 5\times 10^5,1\leq \sum\limits k\leq 5\times 10^5\)
暴力dp:
设dp[u]为以u为根的子树中,割掉所有给定点的最小代价
转移的时候要分两种情况:
1.若u不是给定点,则dp[u] = min(u到根节点的所有边的最小边长,割掉所有含有给定点的子树)
2.若u是给定点,显然他必须与1号点分离,所以dp[u]=u到根节点的所有边的最小边长
这为什么对呢?为什么不会重复计算呢?
因为我每次取min,如果一条边被选过一次了,dp值\(\geq\)只选一次这条边的值
每次都对整棵树dp,复杂度是\(O(nm)\)的
这部分代码:(可能是我写的比较丑?好像有人可以50,我只有30)
void dfs1(int x,int fa){
for (int i = ed1.head[x];i;i = ed1.nxt[i]){
int to = ed1.to[i];
if (to == fa) continue;
minn[to] = min(minn[x],ed1.w[i]);
dfs1(to,x);
}
}
void dfs2(int x,int fa){
int res = 0;
for (int i = ed1.head[x];i;i = ed1.nxt[i]){
int to = ed1.to[i];
if (to == fa) continue;
//cout<<"-----"<<x<<" "<<to<<" "<<ed1.w[i]<<endl;
dfs2(to,x);
res += dp[to];
}
if (vis[x]) dp[x] = minn[x];
else dp[x] = min(minn[x],res);
}
虚树优化
上述复杂度肯定是不行的,我们发现\(\sum\limits k\)比较小,那么我们从这里入手,来建虚树
虚树的主要思想是:对于一棵树,仅仅保留有用的点,重新构建一棵树
这里有用的点指的是询问点和它们的lca
构建:
首先我们要先对整棵树dfs一遍,求出他们的dfs序,然后对每个节点以dfs序为关键字从小到大排序
同时维护一个栈,表示从根到栈顶元素这条链
-
如果栈为空,那么显然st[1] = x;
-
取LCA = lca(x,st[top]),如果LCA\(\neq\)st[top],将lca底下的链边删边连边
-
如果删完发现LCA不在栈中,将LCA加入栈中,然后再把x加入栈中
void build(int x){
if (top == 0){st[top = 1] = x;return;}
int LCA = lca(x,st[top]);
while (top > 1&&dep[LCA] < dep[st[top-1]]) ed2.add(st[top-1],st[top]),top--;
if (dep[LCA] < dep[st[top]]) ed2.add(LCA,st[top--]);
if (top == 0||LCA != st[top]) st[++top] = LCA;
st[++top] = x;
}
for (int i = 1;i <= q;i++) build(k[i]);
if (top) while (--top) ed2.add(st[top],st[top+1]);
复杂度
因为每次加入新的节点,最多会产生一个新的LCA,那么点数是\(2\times k\)的,复杂度为\(O(2k)\)
完整代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
#define ll long long
using namespace std;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
const int maxn = 5e5+7,inf = 1e18+7;
int n,m;
bool vis[maxn];
struct node{
int to[maxn],nxt[maxn],tot,head[maxn],w[maxn];
node(){tot = 0;memset(head,0,sizeof(head));}
void add(int x,int y,int z){
to[++tot] = y,nxt[tot] = head[x],w[tot] = z,head[x] = tot;
to[++tot] = x,nxt[tot] = head[y],w[tot] = z,head[y] = tot;
}
void add(int x,int y){
to[++tot] = y,nxt[tot] = head[x],head[x] = tot;
to[++tot] = x,nxt[tot] = head[y],head[y] = tot;
}
}ed1,ed2;
int st[maxn],top,minn[maxn],dfn[maxn],f[maxn][30],siz[maxn],dep[maxn],cnt;
void dfs(int x){
dep[x] = dep[f[x][0]]+1,dfn[x] = ++cnt,siz[x] = 1;
for (int i = 1;i <= 22;i++) f[x][i] = f[f[x][i-1]][i-1];
for (int i = ed1.head[x];i;i = ed1.nxt[i]){
int to = ed1.to[i];
if (to == f[x][0]) continue;
f[to][0] = x;
minn[to] = min(minn[x],ed1.w[i]);
dfs(to);
siz[x] += siz[to];
}
}
int lca(int x,int y){
if (dep[x] > dep[y]) swap(x,y);
for (int i = 22;i >= 0;i--){
if (dep[f[y][i]] >= dep[x]) y = f[y][i];
}
if (x == y) return x;
for (int i = 22;i >= 0;i--){
if (f[x][i] != f[y][i]) x = f[x][i],y = f[y][i];
}
return f[x][0];
}
bool cmp(int x,int y){return dfn[x] < dfn[y];}
void build(int x){
if (top == 0){st[top = 1] = x;return;}
int LCA = lca(x,st[top]);
while (top > 1&&dep[LCA] < dep[st[top-1]]) ed2.add(st[top-1],st[top]),top--;
if (dep[LCA] < dep[st[top]]) ed2.add(LCA,st[top--]);
if (top == 0||LCA != st[top]) st[++top] = LCA;
st[++top] = x;
}
ll dp[maxn];
void dfs1(int x,int fa){
int res = 0;
for (int i = ed2.head[x];i;i = ed2.nxt[i]){
int to = ed2.to[i];
if (to == fa) continue;
dfs1(to,x);
res += dp[to];
}
if (vis[x]) dp[x] = minn[x];
else dp[x] = min(minn[x],res);
vis[x] = ed2.head[x] = 0;
}
int k[maxn];
void init(){ed2.tot = top = 0;}
signed main(){
n = read();
for (int i = 1;i < n;i++){
int x = read(),y = read(),z = read();
ed1.add(x,y,z);
}
for (int i = 1;i <= n;i++) minn[i] = inf;
m = read();
dfs(1);
while (m--){
init();
int q = read();
for (int i = 1;i <= q;i++) k[i] = read(),vis[k[i]] = 1;
k[++q] = 1;
sort(k+1,k+q+1,cmp);
for (int i = 1;i <= q;i++) build(k[i]);
if (top) while (--top) ed2.add(st[top],st[top+1]);
dfs1(1,0);
printf("%lld\n",dp[1]);
}
return 0;
}