虚树(树形dp套路)模板——bzoj2286

虚树的核心就是把关键点和关键点的lca重新生成一棵树,然后在这棵树上进行dp

https://www.cnblogs.com/zwfymqz/p/9175152.html  写的很好的博客

建立虚树的核心代码

void insert(int x){
    if(top==1){stk[++top]=x;return;}
    int lca=LCA(x,stk[top]);
    if(lca==stk[top])return;//这里本来要stk[++top]=x的,但是由于本题特殊性,所以删去优化时间
    while(top>1 && dfn[lca]<=dfn[stk[top-1]])
        add_edge(stk[top-1],stk[top]),--top;
    if(lca!=stk[top])//如果lca不是关键点,把它做成关键点,即把lca和栈顶元素连边,然后栈顶元素出栈,lca进栈
        add_edge(lca,stk[top]),stk[top]=lca;
    stk[++top]=x;
} 

本题的ac代码

#include<bits/stdc++.h>
using namespace std;
#define maxn 250005
#define ll long long 
struct Edge{int to,nxt,w;}e[maxn<<1];
int head[maxn],tot;
void add(int u,int v,int w){
    e[tot].nxt=head[u];e[tot].to=v;e[tot].w=w;head[u]=tot++;
}
//树链剖分
ll Min[maxn];
int fa[maxn],son[maxn],topf[maxn],size[maxn],d[maxn],dfn[maxn],ind;
void dfs1(int u,int pre){
    fa[u]=pre;size[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre)continue;
        d[v]=d[u]+1;
        Min[v]=min(Min[u],(ll)e[i].w);
        dfs1(v,u);
        size[u]+=size[v];
        if(size[v]>size[son[u]])son[u]=v;
    }
}
void dfs2(int u,int tp){
    topf[u]=tp;dfn[u]=++ind;
    if(!son[u])return;
    dfs2(son[u],tp);
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(!topf[v])
            dfs2(v,v);
    }
}
//求lca
int LCA(int x,int y){
    while(topf[x]!=topf[y]){
        if(d[topf[x]]<d[topf[y]])
            swap(x,y);
        x=fa[topf[x]];
    }
    if(d[x]>d[y])swap(x,y);
    return x;
}
//建立虚树 

int top,stk[maxn];
vector<int>G[maxn];
void add_edge(int u,int v){G[u].push_back(v);}

void insert(int x){
    if(top==1){stk[++top]=x;return;}
    int lca=LCA(x,stk[top]);
    if(lca==stk[top])return;
    while(top>1 && dfn[lca]<=dfn[stk[top-1]])//把lca以下的关键点都入栈 
        add_edge(stk[top-1],stk[top]),--top;
    if(lca!=stk[top])//如果lca不是关键点,把它做成关键点 
        add_edge(lca,stk[top]),stk[top]=lca;
    stk[++top]=x;
} 
//在虚树上dp
ll dfs3(int u){
    if(G[u].size()==0)return Min[u];
    ll res=0;
    for(int i=0;i<G[u].size();i++)
        res+=dfs3(G[u][i]);
    G[u].clear();
    return min(res,Min[u]);
}

int cmp(int a,int b){return dfn[a]<dfn[b];}

int n,m,x,y,z,a[maxn];

int main(){
    memset(head,-1,sizeof head);
    Min[1]=1ll<<60;
    cin>>n;
    for(int i=1;i<n;i++){
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);add(y,x,z);
    }
    d[1]=1;
    dfs1(1,0),dfs2(1,1);
    
    cin>>m;
    while(m--){
        int k;scanf("%d",&k);
        for(int i=1;i<=k;i++)scanf("%d",&a[i]);
        sort(a+1,a+1+k,cmp);
        top=1;
        stk[top]=1;G[1].clear();
        
        for(int i=1;i<=k;i++)insert(a[i]);
        while(top)
            add_edge(stk[top-1],stk[top]),top--;
        
        //进行dp 
        cout<<dfs3(1)<<endl;
    }
    return 0;
}
View Code

 

posted on 2019-07-15 23:48  zsben  阅读(279)  评论(0编辑  收藏  举报

导航