bzoj2870

点分治

跟路径有关的立马想到了点分治

我们需要统计每条路径的答案(路径上的点数 × 路径上的最小权值)。根据点分治的过程,只需考虑经过当前分治中心的路径:枚举路径的一端,并令这一端到分治中心的最小值作为整条路径的最小值

那么我们用树状数组保存之前遍历过的子树中每个最小值对应的最长长度,查询时取"最小值不小于当前最小值"的后缀最大值。由于每次只能查询到之前处理过的子树,所以要按正序、逆序各处理一遍子树,即跑两遍即可

#include<bits/stdc++.h>
using namespace std;
// Capacity for per-node arrays (nodes are numbered from 1).
const int N = 50005;

// Adjacency lists of the tree; each edge is stored in both directions.
std::vector<int> G[N];

// Number of nodes read from the input.
int n;
// Best centroid candidate found so far by getroot(); index 0 is the sentinel.
int root;
// Upper bound for Fenwick indices: the largest (weight + 1) seen in the input.
int m;

// Best (vertex count * minimum weight) found so far.
long long ans;

// Fenwick tree cells; index (min weight + 1) -> longest stored length.
int t[2 * N];
// vis[x] != 0 once x has been used as a decomposition centre.
int vis[N];
// Node weights.
int v[N];
// Largest remaining component if the node were removed (filled by getroot()).
int mx[N];
// Subtree sizes (filled by getroot()). Unqualified uses of this name can be
// ambiguous with std::size when the file is compiled as C++17 or later.
int size[N];
int getsize(int u, int last) {
    int ret = 1;
    for(int i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if(vis[v] || v == last) {
            continue;
        }
        ret += getsize(v, u);
    }
    return ret;
}
void getroot(int u, int last, int S) {
    mx[u] = 0;
    size[u] = 1;
    for(int i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if(v == last || vis[v]) {
            continue;
        }
        getroot(v, u, S);
        mx[u] = max(mx[u], size[v]);
        size[u] += size[v];
    }
    mx[u] = max(mx[u], S - size[u]);
    if(mx[u] < mx[root]) {
        root = u;
    }
}
// Walks the Fenwick chain downwards from x (x, x - lowbit(x), ...) and, for
// every cell on it, either raises the cell to at least d or, when d == -1,
// resets the cell to 0. Index 0 is never touched.
void update(int x, int d) {
    const bool clearing = (d == -1);
    while(x != 0) {
        if(clearing) {
            t[x] = 0;
        } else if(t[x] < d) {
            t[x] = d;
        }
        x -= x & -x;
    }
}
// Walks the Fenwick chain upwards from x (x, x + lowbit(x), ...) while the
// index stays within m and returns the largest cell seen (0 if none is
// larger). Combined with update() this yields the maximum stored at any
// index >= x.
int query(int x) {
    int best = 0;
    while(x <= m) {
        if(t[x] > best) {
            best = t[x];
        }
        x += x & -x;
    }
    return best;
}
void dp(int u, int last, int mn, int len) {
    mn = min(mn, v[u]);
    ans = max(ans, (long long)(len + query(mn + 1)) * mn);
    for(int i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if(vis[v] || v == last) {
            continue;
        }       
        dp(v, u, mn, len + 1);
    }
}
void dfs(int u, int last, int mn, int len, int mp) {
    mn = min(mn, v[u]);
    update(mn + 1, min(len + 1, mp));
    for(int i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if(vis[v] || v == last) {
            continue;
        }
        dfs(v, u, mn, len + 1, mp);
    }
}
// One directional sweep over the unvisited subtrees hanging off centre u.
//
// The centre itself is stored with length 1 at index v[u] + 1, the same
// (minimum + 1) indexing dfs() uses for every other node. Each subtree is
// first scored with dp() against what is already stored and only then
// stored with dfs(), so a subtree is never combined with itself. Afterwards
// every cell that was written is cleared again, leaving the Fenwick tree
// untouched for the next sweep.
//
//   forward - true: children in adjacency order; false: reverse order.
static void sweepSubtrees(int u, bool forward) {
    const int kStore = 1000000000;  // cap that makes dfs() store len + 1
    const int slot = v[u] + 1;
    const int deg = (int)G[u].size();

    update(slot, 1);
    for(int k = 0; k < deg; ++k) {
        const int child = G[u][forward ? k : deg - 1 - k];
        if(vis[child]) {
            continue;
        }
        dp(child, u, v[u], 1);
        dfs(child, u, v[u], 1, kStore);
    }

    // Undo every write made above (mp == -1 turns dfs() into an eraser).
    for(int k = 0; k < deg; ++k) {
        const int child = G[u][forward ? k : deg - 1 - k];
        if(!vis[child]) {
            dfs(child, u, v[u], 1, -1);
        }
    }
    update(slot, -1);
}

// Centroid-decomposition step for centre u.
//
// Marks u as used, scores all paths that run through u, then recurses into
// each remaining component after locating its own centre with getroot().
//
// Two sweeps are needed because dp() can only see subtrees stored before
// the one it is scoring: the forward sweep pairs each subtree with earlier
// ones, the reverse sweep with later ones.
//
// Both sweeps use index v[u] + 1 for the centre and clear it afterwards.
// Previously the first sweep left its entry in the tree and the second one
// used index v[u]; the reported answer is the same either way, but the tree
// is now guaranteed to be all zeros when this function returns.
//
// NOTE(review): ans is only ever updated from dp(), which is called on
// neighbours of the centre, so a path consisting of a single vertex is never
// scored — confirm against the problem statement whether that is required.
void solve(int u) {
    vis[u] = 1;

    sweepSubtrees(u, true);
    sweepSubtrees(u, false);

    for(size_t k = 0; k < G[u].size(); ++k) {
        const int child = G[u][k];
        if(vis[child]) {
            continue;
        }
        root = 0;  // sentinel: mx[0] is huge, so any real node beats it
        getroot(child, u, getsize(child, u));
        solve(root);
    }
}
// Reads the tree from standard input, runs the decomposition and prints the
// best value.
//
// Input: n, then n node weights, then n - 1 edges as pairs of node numbers.
// Output: ans followed by a newline.
int main() {
    scanf("%d", &n);

    // Weights; m tracks the largest Fenwick index (weight + 1) that can occur.
    for(int node = 1; node <= n; ++node) {
        scanf("%d", &v[node]);
        if(v[node] + 1 > m) {
            m = v[node] + 1;
        }
    }

    // Undirected edges.
    for(int e = 1; e < n; ++e) {
        int a, b;
        scanf("%d%d", &a, &b);
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // Node 0 is the sentinel "root": any real node has a smaller mx.
    mx[0] = 1000000000;
    const int total = getsize(1, 0);
    getroot(1, 0, total);
    solve(root);

    printf("%lld\n", ans);
    return 0;
}
View Code

 

posted @ 2018-02-24 21:48  19992147  阅读(155)  评论(0)  编辑  收藏  举报