BZOJ 2286 消耗战 - 虚树 + 树型dp

传送门

题目大意:

每次给出k个特殊点,回答将这些特殊点与根节点断开至少需要多少代价。

题目分析:

虚树入门 + 树型dp:
刚刚学习完虚树(好文),就来这道入门题签个到。
虚树就是将树中的一些关键点提取出来,在不改变父子关系的条件下用$O(mlog n) \(组成一颗新树(m特殊点数,n总数),大小为\)O(m)$,以便降低后续dp等的复杂度。
建好虚树过后就可以进行普通的dp了(mn[u]表示原图中u到根节点的最短边长):

\[dp[u] = mn[u] (u是特殊点) \]

\[dp[u] = min(mn[u], \sum{dp[son[u]]}) (u不是特殊点) \]

此题就当做是虚树模板了。

注意每一次重建虚树时不用将1~n的边信息全部清空,不然你会见识到clear的惊人速度(T飞)。

code

#include<bits/stdc++.h>
using namespace std;
#define maxn 250050
#define oo 0x3f3f3f3f
typedef long long ll;
typedef pair<int, ll> pil;
namespace IO{
    inline int read(){
        int i = 0, f = 1; char ch = getchar();
        for(; (ch < '0' || ch > '9') && ch != '-'; ch = getchar());
        if(ch == '-') f = -1, ch = getchar();
        for(; ch >= '0' && ch <= '9'; ch = getchar()) i = (i << 3) + (i << 1) + (ch - '0');
        return i * f;
    }
    inline void wr(ll x){
        if(x < 0) x = -x, putchar('-');
        if(x > 9) wr(x / 10);
        putchar(x % 10 + '0');
    } 
}using namespace IO;
int n, m;
vector<pil> g[maxn];
vector<int> vg[maxn];
ll dp[maxn], mn[maxn];
int dfn[maxn], clk, dep[maxn], vir[maxn], virCnt, par[maxn], rt;
bool key[maxn];
namespace LCA{
    int pos[maxn], top[maxn], son[maxn], sze[maxn], tot, fa[maxn];
    inline void dfs1(int u, int f){
        dfn[u] = ++clk;
        dep[u] = dep[f] + 1;
        fa[u] = f;
        sze[u] = 1;
        for(int i = g[u].size() - 1; i >= 0; i--){
            int v = g[u][i].first;
            if(v == f) continue;
            mn[v] = min(mn[u], g[u][i].second);
            dfs1(v, u);
            sze[u] += sze[v];
            if(sze[v] > sze[son[u]] || !son[u]) son[u] = v;
        }
    }
    inline void dfs2(int u, int f){
        if(son[u]){
            pos[son[u]] = ++tot;
            top[son[u]] = top[u];
            dfs2(son[u], u);
        }
        for(int i = g[u].size() - 1; i >= 0; i--){
            int v = g[u][i].first;
            if(v == f || v == son[u]) continue;
            pos[v] = ++tot;
            top[v] = v;
            dfs2(v, u);
        }
    }
    inline void splitTree(){
        dfs1(1, 0);
        pos[tot = 1] = top[1] = 1;
        dfs2(1, 0);
    }
    inline int getLca(int u, int v){
        while(top[u] != top[v]){
            if(dep[top[u]] < dep[top[v]]) swap(u, v);
            u = fa[top[u]];
        }
        return dep[u] < dep[v] ? u : v;
    }
}
 
inline bool cmp(int u, int v){
    return dfn[u] < dfn[v];
}
 
inline void buildVir(){
    static int stk[maxn], top;
    top = 0;
    sort(vir + 1, vir + virCnt + 1, cmp);
    int oriSze = virCnt;
    for(int i = 1; i <= oriSze; i++){
        int u = vir[i];
        if(!top){
            stk[++top] = u;
            par[u] = 0;
            continue;
        }
        int lca = LCA::getLca(stk[top], u);
        while(dep[lca] < dep[stk[top]]){
            if(dep[stk[top - 1]] < dep[lca]) par[stk[top]] = lca;
            --top;
        }
        if(lca != stk[top]){
            vir[++virCnt] = lca;
            par[lca] = stk[top];
            stk[++top] = lca;
        }
        par[u] = lca;
        stk[++top] = u;
    }
    for(int i = 1; i <= virCnt; i++) vg[vir[i]].clear();
    for(int i = 1; i <= virCnt; i++){
        int u = vir[i];
        key[u] = ((i <= oriSze) ? 1 : 0);
        if(par[u]) vg[par[u]].push_back(u);
    }
    sort(vir + 1, vir + virCnt + 1, cmp);
}
 
inline void DP(int u){
//  cout<<u<<"!";
    ll ret = 0;
    for(int i = vg[u].size() - 1; i >= 0; i--){
        int v = vg[u][i];
        DP(v);
        ret += dp[v];
    }
    if(key[u]) dp[u] = mn[u];
    else dp[u] = min(mn[u], ret);
}
 
inline void solve(){
    buildVir();
    DP(vir[1]);
    wr(dp[vir[1]]);
    putchar('\n');
}
 
int main(){
	freopen("h.in", "r", stdin);
    n = read();
    for(int i = 1; i < n; i++){
        int x = read(), y = read();
        ll c = 1ll * read();
        g[x].push_back(pil(y, c));
        g[y].push_back(pil(x, c));
    }
    memset(mn, oo, sizeof mn);
    LCA::splitTree();
    m = read();
    for(int i = 1; i <= m; i++){
        int k = read();
        virCnt = 0;
        for(int j = 1; j <= k; j++)
            vir[++virCnt] = read();
        solve();
    }
    return 0;
}
posted @ 2017-10-22 21:39  CzYoL  阅读(217)  评论(0编辑  收藏  举报