虚树学习笔记

虚树定义

虚树是一棵虚拟构建的树,这棵树的特点是只包含关键点以及关键的点,这些点满足在原树中的关系,而其他点和边相当于都做了路径压缩

例题引入

luogu2495 [SDOI2011]消耗战

题目描述

一颗树,上面有 k 个资源点,拆一些边,使得 1 号点不能到达任何资源点。现在要使得拆除的边的权值和最小。总共有 m 次询问,每次给出资源点。

题解

题目中的资源点在虚树中就相当于关键点

当我们不考虑虚树时,我们考虑怎么做这个题

我们发现可以树形DP

\(dp[i]\)表示以\(i\)为根的子树中不与关键点连通的最小代价,\(u\)\(i\)的儿子

则有DP方程

\[dp[i]= \begin{cases} dp[i]+min(dp[u],e[i][u])~~~~u不是关键点\\ dp[i]=dp[i]+e[i][u]~~~~~~~~~~~~u是关键点 \end{cases} \]

很显然,这样的复杂度是\(O(nm)\)的,并不符合我们的要求

我们考虑有没有更优的做法或优化

我们重新观察题面

我们发现关键点的总数量只有\(n\)

又发现我们的算法中其实有很多点的子树中并不包含关键点,也就是说根本不需要算它的\(dp\)

因此,我们知道肯定要做出一棵很小的树来快速解决问题

而这棵树中需要储存的东西只有关键点和他们的\(LCA\),也就是一棵虚树

这棵虚树怎么去建呢(图片来自 oi-wiki.org)

对于一个这样的图

1

红色是关键点,红点和黑点都是虚树中的点,黑边是虚树中的边

1

1

1

通过这几张图,我们具象化的了解了虚树的形状,接下来考虑如何建一棵虚树

我们不能\(O(n^2)\)\(LCA\),不难想到可以按照\(dfs\)序排序后求相邻的\(LCA\)

我们知道,对于一棵虚树,只要保证祖先后代关系不变即可随便加点

因此为了方便我们把根节点加进去

然后,我们来做出一个方案建立虚树

我们开一个单调栈,维护虚树上的一条链

如果当前我们要加进去的节点\(now\),与栈顶节点\(top\)\(LCA\)\(top\)就直接入栈

如果不是,则弹栈直到与\(top\)\(LCA\)\(top\)时将\(now\)入栈

当然,在这个过程中不要忘了将栈顶与弹出的节点连边

当我们把全部过程做完后,虚树也就建好了

这时候,我们重新回到那个题

我们处理出树上每个点到根的路径上的最小值,然后直接按照我们原来的方式\(dp\)就行了

#include <ctime>
#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#define int long long
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
using namespace std;
const int maxn=5e5+5;
int n,N,m,beg[maxn],tot,Min[maxn],fa[maxn][26],dp[maxn],dfn[maxn],cnt,a[maxn],vis[maxn],st[maxn],top,dep[maxn];
struct edge{
    int nex,to,w;
}e[maxn*2];
void add(int x,int y,int z) {
    e[++tot]=(edge){beg[x],y,z};
    beg[x]=tot;
}
vector<int>vec[maxn*2];
void chkmax(int &x,int y) {if (x<y) x=y;}
void chkmin(int &x,int y) {if (x>y) x=y;}
int read() {
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0' || ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while(ch<='9' && ch>='0') {x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return x*f;
}
void dfs(int now,int Fa) {
    fa[now][0]=Fa;
    dfn[now]=++cnt;
    dep[now]=dep[Fa]+1;
    for (int i=beg[now];i;i=e[i].nex) {
        int nex=e[i].to;
        if (nex==Fa) continue;
        chkmin(Min[nex],e[i].w);
        chkmin(Min[nex],Min[now]);
        dfs(nex,now);
    }
}
bool cmp(int x,int y) {
    return dfn[x]<dfn[y];
}
int LCA(int x,int y) {
    if (dep[x]<dep[y]) swap(x,y);
    for (int i=20;i>=0;i--) {
        int fx=fa[x][i];
        if (dep[fx]>=dep[y]) x=fx;
        if (x==y) return x;
    } 
    for (int i=20;i>=0;i--) {
        if (fa[x][i]!=fa[y][i]) {
            x=fa[x][i];
            y=fa[y][i];
        }
    }
    return fa[x][0];
}
void build() {
    sort(a+1,a+1+N,cmp);
    st[top=1]=1;
    for (int i=1;i<=N;i++) {
        int lca=LCA(st[top],a[i]);
        if (lca!=st[top]) {
            while(dfn[lca]<dfn[st[top-1]]) {
                vec[st[top]].push_back(st[top-1]);
                vec[st[top-1]].push_back(st[top]);
                top--;
            }
            if (lca!=st[top-1]) {
                vec[st[top]].push_back(lca);
                vec[lca].push_back(st[top]);
                st[top]=lca;
            }
            else {
                vec[st[top]].push_back(lca);
                vec[lca].push_back(st[top]);
                top--;
            }
        }
        st[++top]=a[i];
    }
    for (int i=1;i<top;i++) {
        vec[st[i]].push_back(st[i+1]);
        vec[st[i+1]].push_back(st[i]);
    } 
}
void init() {
    for (int j=1;j<=20;j++)
        for (int i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
}
int solve(int now,int Fa) {
    int ans=0,res=0;
    for (int i=0;i<vec[now].size();i++) {
        int nex=vec[now][i];
        if (nex==Fa) continue;
        res+=solve(nex,now);
    }
    if (vis[now]) ans=Min[now];
    else ans=min(Min[now],res);
    vec[now].clear();
    vis[now]=0;
    return ans;
}
signed main() {
    file(luogu2495);
    n=read();
    memset(Min,0x3f,sizeof(Min));
    for (int i=1;i<n;i++) {
        int x=read(),y=read(),w=read();
        add(x,y,w);add(y,x,w);
    }
    dfs(1,0);
    init();
    m=read();
    while(m--) {
        N=read();
        for (int i=1;i<=N;i++) a[i]=read(),vis[a[i]]=1;
        build();
        cout<<solve(1,1)<<endl;
        for (int i=1;i<=N;i++) vis[a[i]]=0;
    }
    return 0;
}
/*
 ---
/Y A\
\___/   /
   \   /
    \ /*
—————\ *
      \*____
      |*    \
      |*
      \*
       *
       *
*/
posted @ 2022-03-22 21:40  xzj213  阅读(29)  评论(0编辑  收藏  举报