bzoj3611

$虚树+树形dp$

$虚树复习$

$虚树是用来解决每次从树中选出一些点进行统计一类问题的方法,每次把这些点建出虚树,最多不超过2*k个点$

$建虚树的过程挺简单的,就是用一个栈记录当前dfs的过程,如果出现分叉就回溯。重点在于虚点的加入,就是两个点的lca,这个判一下就行了$

$树形dp的过程比较简单,就不说了,注意虚点的统计$

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 5, inf = 0x3f3f3f3f;
int n, q, k, top, dfs_clock;
ll ans, ans1, ans2;
int dfn[N], dep[N], fa[N][21], mark[N], a[N], st[N];
ll mn[N], mx[N], sum[N], size[N];
vector<int> G[N];
vector<pair<int, int> > g[N];
bool cmp(int i, int j) {
    return dfn[i] < dfn[j];
}
void dfs(int u, int last) {
    dfn[u] = ++dfs_clock;
    for(int i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if(v == last) {
            continue;
        }
        fa[v][0] = u;
        dep[v] = dep[u] + 1;
        dfs(v, u);
    }
}
int lca(int u, int v) {
    if(dep[u] < dep[v]) {
        swap(u, v);
    }
    int d = dep[u] - dep[v];
    for(int i = 20; i >= 0; --i) {
        if(d & (1 << i)) {
            u = fa[u][i];
        }
    }
    if(u == v) {
        return u;
    }
    for(int i = 20; i >= 0; --i) {
        if(fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}
void dp(int u) {
    if(mark[u]) {
        mn[u] = 0;
    } else {
        mn[u] = 0x3f3f3f3f;
    }
    mx[u] = 0; 
    sum[u] = 0;
    size[u] = mark[u];
    for(int i = 0; i < g[u].size(); ++i) {
        int v = g[u][i].first, w = -g[u][i].second;
        dp(v);      
        if(mark[u]) {
            ans += size[v] * sum[u] + size[u] * (sum[v] + w * size[v]); 
            ans1 = min(ans1, mn[v] + w);
            ans2 = max(ans2, mx[v] + w + mx[u]); 
        } else {
            ans += size[v] * sum[u] + size[u] * (sum[v] + w * size[v]);
            if(i) ans1 = min(ans1, mn[u] + mn[v] + w);
            if(i) ans2 = max(ans2, mx[v] + w + mx[u]);
        }
        sum[u] += sum[v] + w * size[v];
        mn[u] = min(mn[u], mn[v] + w);
        mx[u] = max(mx[u], mx[v] + w);
        size[u] += size[v];
    }
    g[u].clear();
}
int main() {
    scanf("%d", &n);
    for(int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0);
    for(int j = 1; j <= 20; ++j) 
        for(int i = 1; i <= n; ++i) 
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
    scanf("%d", &q);
    while(q--) {
        scanf("%d", &k);
        for(int i = 1; i <= k; ++i) {
            scanf("%d", &a[i]);
            mark[a[i]] = 1;
        }
        sort(a + 1, a + k + 1, cmp);
        st[top = 1] = 1;
        for(int i = 1; i <= k; ++i) {
            int f = lca(a[i], st[top]);
            while(top > 1 && dfn[f] < dfn[st[top - 1]]) {
                g[st[top - 1]].push_back({st[top], dep[st[top - 1]] - dep[st[top]]});
                --top;
            }
            if(dfn[f] < dfn[st[top]]) {
                g[f].push_back({st[top], dep[f] - dep[st[top]]});
                --top;
            }
            if(f != st[top]) {
                st[++top] = f;
            }
            if(a[i] != 1) {
                st[++top] = a[i];
            }
        }
        while(top > 1) {
            g[st[top - 1]].push_back({st[top], dep[st[top - 1]] - dep[st[top]]});
            --top; 
        }
        ans = 0;
        ans1 = inf;
        ans2 = 0;
        dp(1);
        printf("%lld %lld %lld\n", ans, ans1, ans2);
        for(int i = 1; i <= k; ++i) {
            mark[a[i]] = 0;
        }
    }
    return 0;
}
View Code

 

posted @ 2018-01-29 19:35  19992147  阅读(110)  评论(0编辑  收藏  举报