虚树学习笔记
虚树学习笔记
[SDOI2011]消耗战
题意
给一棵\(n\)个点,带边权的树。
给\(m\)组询问,每组有\(k_i\)个关键点,你需要切断一些边,使得每个点都到不了根节点,求最小代价。\
\(n <= 2.5 \cdot 10^5, m <= 5 \cdot 10^5,\sum k_i <= 5 \cdot 10^5\)
Solve 1
对于每组询问,做一个\(dp\),设\(f[x]\)表示切断\(x\)和他的子树所需最小代价,转移分两种
- \(x\)是关键点,答案为\(x\)到根路径最小值
- \(x\)不是关键点,答案为切断所有儿子的值和第\(1\)种取\(min\)
复杂度\(O(nm)\)
\(can \ we \ do \ better?\)
Solve 2
要用到所讲的虚树。
我们发现转移过程中,对于转移有贡献的只有关键点以及他们之间的祖先,于是我们可以简化树的结构。
把关键点按\(dfs\)序排序,相邻两个求出\(lca\)并建边。最后在虚树上做\(dp\),复杂度\(O(n\ log \ n + \sum k_i log \ n)\)
具体实现用一个栈维护一条树链,排序后一次加入点。
设当前加入的点\(u\)
- 如果\(top <=1\) ,\(stk[++top] = u\)
- 设\(l = lca(u,stk[top])\),如果\(l == stk[top]\),那么\(u\)应该接在\(stk[top]\)底下,\(stk[++top] = u\)
- 否则说明\(u\)已经是一个新的子树,持续弹栈直到\(dfn[stk[top-1]] < dfn[l] <= dfn[stk[top]]\),如果\(l != stk[top]\),把\(stk[top]\)接在\(l\)后面,\(stk[top] = l\),最后\(stk[++top] = u\)
void insert(int u){
if(top <= 1) return stk[++top] = u,void();
int l = lca(u,stk[top]);
if(l == stk[top]) return stk[++top] = u,void();
while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
add(stk[top-1],stk[top]); top--;
}
if(l != stk[top]) add(l,stk[top]),stk[top] = l;
stk[++top] = u;
return ;
}
Code
#include<bits/stdc++.h>
#define int long long
#define N 1000015
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
#define inf 0x3f3f3f3f3f3f3f3f
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define lowbit(i) ((i)&(-i))
#define VI vector<int>
#define all(x) x.begin(),x.end()
using namespace std;
int n,m,a[N],Min[N],k,dfn[N],clk;
vector<pii> e[N];
VI g[N];
void dfs(int u,int fa){
dfn[u] = ++clk;
for(auto I:e[u]){
int v = I.fi,w = I.se;
if(v == fa) continue;
Min[v] = min(Min[u],w);
dfs(v,u);
}
}
bool cmp(int u,int v){
return dfn[u] < dfn[v];
}
namespace LCA{
int fa[N][24],dep[N];
void Dfs(int u,int f){
fa[u][0] = f; dep[u] = dep[f]+1;
for(auto I:e[u]){
int v = I.fi;
if(v == f) continue;
Dfs(v,u);
}
}
void init(){
rep(j,1,21){
rep(i,1,n){
fa[i][j] = fa[fa[i][j-1]][j-1];
}
}
}
int lca(int u,int v){
if(dep[u] < dep[v]) swap(u,v);
int t = dep[u] - dep[v];
per(i,0,21){
if((1<<i)&t) u = fa[u][i];
}
if(u == v) return u;
per(i,0,21){
if(fa[u][i] != fa[v][i]) u = fa[u][i],v = fa[v][i];
}
return fa[u][0];
}
}
using namespace LCA;
int stk[N],top;
void add(int u,int v){
//printf("%lld -> %lld\n",u,v);
g[u].pb(v);
}
void insert(int u){
if(top <= 1) return stk[++top] = u,void();
int l = lca(u,stk[top]);
if(l == stk[top]) return stk[++top] = u,void();
while(top > 1 && dfn[l] <= dfn[stk[top-1]]){
add(stk[top-1],stk[top]); top--;
}
if(l != stk[top]) add(l,stk[top]),stk[top] = l;
stk[++top] = u;
return ;
}
void build(){
top = 0;
stk[++top] = 1;
rep(i,1,k) insert(a[i]);
while(top > 1) add(stk[top-1],stk[top]),top--;
}
bool gkp[N];
int dp(int u){
int res = 0;
if(g[u].size() == 0){
//printf("u: %lld val: %lld\n",u,Min[u]);
return Min[u];
}
for(auto v:g[u]){
res += dp(v);
}
g[u].clear();
if(!gkp[u]) return min(res,Min[u]);
//printf("u: %lld val: %lld\n",u,res);
return Min[u];
}
signed main(){
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%lld",&n);
memset(Min,0x3f,sizeof Min);
rep(i,2,n){
int u,v,w; scanf("%lld%lld%lld",&u,&v,&w);
e[u].pb(mp(v,w)); e[v].pb(mp(u,w));
}
dfs(1,0);
// rep(i,1,n) printf("%lld ", Min[i]);
// printf("\n");
Dfs(1,0); init();
// rep(i,1,n){
// rep(j,i+1,n){
// printf("(i,j): (%lld,%lld) lca: %lld\n",i,j,lca(i,j));
// }
// }
scanf("%lld",&m);
rep(_,1,m){
scanf("%lld",&k); rep(i,1,k) scanf("%lld",&a[i]),gkp[a[i]] = 1;
sort(a+1,a+k+1,cmp); //puts("sort finished");
build(); //puts("build finished");
printf("%lld\n",dp(1));
rep(i,1,k) gkp[a[i]] = 0;
}
return 0;
}