【BZOJ2286】消耗战(SDOI2011)-虚树+树形DP

测试地址:消耗战
做法:本题需要用到虚树+树形DP。
这题如果只有一个询问,相信大家都会做了,比较裸的树形DP。但是询问次数很大,每次询问都O(n)DP的话,总的时间复杂度就是O(nm),无法承受。但是我们发现,总共涉及的询问点数不大,那么我们迫切需要一个关于k而不是关于n的算法。这时候就要拿出大杀器——虚树了。
虚树其实应该不算是一种数据结构,它是一类树上题的一种处理技巧。想法其实很简单,因为我们只询问k个点,那么我们就只把这k个点建在一棵树上就好了。但由于我们还要维护边上的信息,所以需要一些中间节点的支撑,这些中间节点就是询问点的LCA。我们把询问点按树上的DFS序排序,我们发现对于按DFS序排序的三个点x,y,zLCA(x,z)必定等于LCA(x,y)LCA(y,z)中的一个。所以我们只需要求出DFS中相邻询问点的LCA即可。接下来就是建虚树了,当然我们不能直接在原树上DFS,不然时间复杂度又变O(n)了。我们按照DFS序将标记过的点放入栈中,我们需要时刻保证栈中的元素都在从根出发的一条链上。如果栈顶和当前要加入的点的LCA不等于栈顶,说明当前加入点不是栈顶的子孙,那么我们加一条从次栈顶(就是栈顶下面的一个元素)到栈顶的边,边权就是两点之间路径的最小值,可以倍增求出,然后一直下去直到当前点为某个栈中点的子孙,将当前点放入栈顶。由于每个元素仅入栈一次且出栈一次,所以复杂度是和k相关的。
建出虚树后就可以在虚树上DP了,以上算法总的时间复杂度为O(kilogn),可以通过本题。
以下是本人代码:

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const ll inf=1000000000;
int n,m,k,first[250010]={0},tot=0,tim=0;
int fa[250010][21],dep[250010],b[1000010],order[250010];
int st[250010],top=0;
int firsti[250010]={0},toti=0;
ll f[250010],mn[250010][21];
bool res[250010]={0};
struct edge
{
    int v,next;
    ll w;
}e[500010],ei[500010];

void insert(int a,int b,ll w)
{
    e[++tot].v=b,e[tot].w=w,e[tot].next=first[a],first[a]=tot;
}

void inserti(int a,int b,ll w)
{
    ei[++toti].v=b,ei[toti].w=w,ei[toti].next=firsti[a],firsti[a]=toti;
}

void init(int v)
{
    order[v]=++tim;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa[v][0])
        {
            mn[e[i].v][0]=e[i].w;
            fa[e[i].v][0]=v;
            dep[e[i].v]=dep[v]+1;
            init(e[i].v);
        }
}

int lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;i--)
        if (dep[fa[x][i]]>=dep[y]) x=fa[x][i];
    if (x==y) return x;
    for(int i=20;i>=0;i--)
        if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}

ll findup(int x,int y)
{
    ll ans=inf*inf;
    for(int i=20;i>=0;i--)
        if (dep[fa[y][i]]>=dep[x]) ans=min(ans,mn[y][i]),y=fa[y][i];
    return ans;
}

bool cmp(int a,int b)
{
    return order[a]<order[b];
}

void build()
{
    sort(b+1,b+k+1,cmp);
    toti=0;
    for(int i=1;i<k;i++)
        b[k+i]=lca(b[i],b[i+1]);
    b[k<<1]=1;
    sort(b+1,b+(k<<1)+1,cmp);
    top=0;
    for(int i=1;i<=(k<<1);i++)
        if (i==1||b[top]!=b[i]) b[++top]=b[i];
    k=top;
    top=1;st[1]=1;
    for(int i=2;i<=k;i++)
    {
        while (top>1&&lca(st[top],b[i])!=st[top])
        {
            inserti(st[top-1],st[top],findup(st[top-1],st[top]));
            top--;
        }
        st[++top]=b[i];
    }
    while (top>1)
    {
        inserti(st[top-1],st[top],findup(st[top-1],st[top]));
        top--;
    }
}

void dp(int v)
{
    f[v]=0;
    for(int i=firsti[v];i;i=ei[i].next)
    {
        dp(ei[i].v);
        if (res[ei[i].v]) f[v]+=ei[i].w;
        else f[v]+=min(ei[i].w,f[ei[i].v]);
    }
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int u,v;
        ll w;
        scanf("%d%d%lld",&u,&v,&w);
        insert(u,v,w),insert(v,u,w);
    }

    fa[1][0]=fa[0][0]=0;
    mn[1][0]=mn[0][0]=inf*inf;
    dep[1]=1;dep[0]=0;
    init(1);
    for(int i=1;i<=20;i++)
        for(int j=1;j<=n;j++)
        {
            fa[j][i]=fa[fa[j][i-1]][i-1];
            mn[j][i]=min(mn[j][i-1],mn[fa[j][i-1]][i-1]);
        }

    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&k);

        for(int j=1;j<=k;j++)
        {
            scanf("%d",&b[j]);
            res[b[j]]=1;
        }
        build();

        dp(1);
        printf("%lld\n",f[1]);

        for(int j=1;j<=k;j++)
            res[b[j]]=firsti[b[j]]=0;
    }

    return 0;
}
posted @ 2018-03-07 11:28  Maxwei_wzj  阅读(87)  评论(0编辑  收藏  举报