hdu6133 Army Formations 线段树合并

给你一棵有n个节点的二叉树,每个节点有一个权值,对于一棵子树u,将u的子树中的节点权值从大到小排序,令sz[u]为子树u的大小,

则ans[u] = 1 * a[1] + 2 * a[2] + ... + sz[u] * a[sz[u]],其中a[1] >= a[2] >= ... >= a[sz[u]] 是子树u中排好序的权值。求所有节点的答案。

对每个节点建立权值线段树,dfs整棵树,线段树合并

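ans[rt] = ans[ls[rt]] + ans[rs[rt]] + size[rs[rt]] * w[ls[rt]],w表示某权值区间的权值和,size表示某权值区间内点的个数。按从大到小排序时,右儿子(权值较大)中的点全部排在左儿子之前,所以左儿子中每个权值的系数都要增加size[rs[rt]]。叶子区间内的权值都相同,设为x,则ans = x * size * (size + 1) / 2。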

 

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Capacities: maxn bounds tree vertices (and the discretized value range);
// maxnode bounds the total number of segment-tree nodes allocated by update().
const int maxn = 1e5 + 10;
const int maxnode = 2e6 + 10;

// Linked-list adjacency storage: head[u] is the index of the most recently added
// edge leaving u (or -1), and each edge links to the previous one through `next`.
struct edge {
    int to;    // target vertex of this directed edge
    int next;  // index of the next edge out of the same vertex, or -1
} e[maxn << 1];
int head[maxn], ecnt;

// Clears the adjacency list: no stored edges, every head reset to -1.
void edge_init() {
    std::fill(head, head + maxn, -1);
    ecnt = 0;
}

// Stores the directed edge u -> v; callers add both directions for an undirected edge.
void add(int u, int v) {
    edge &slot = e[ecnt];
    slot.to = v;
    slot.next = head[u];
    head[u] = ecnt++;
}
// a[i]: weight of vertex i as read; main() then replaces it with its 1-based rank in b.
// b[1..m]: the weights sorted ascending with duplicates removed.
int a[maxn], b[maxn];
// root[u]: root node of the value segment tree built for vertex u's subtree (0 = empty).
int root[maxn];
// Per segment-tree node: number of stored weights, left child, right child (0 = none).
int sz[maxnode], ls[maxnode], rs[maxnode];
// Per segment-tree node: ans = sum of k * (k-th largest stored weight) for k = 1..sz,
// sum = plain sum of the stored weights.
long long ans[maxnode], sum[maxnode];
// tot: number of segment-tree nodes allocated in the current test case;
// m: number of distinct weights (the value range is [1, m]).
int tot, m;

// Merges leaf v into leaf u and returns u. Both leaves cover the same single
// discretized value, so every stored element has the same weight. With descending
// positions 1..sz, the leaf's answer is weight * (1 + 2 + ... + sz).
int mergeleaf(int u, int v) {
    sz[u] += sz[v];
    sum[u] += sum[v];
    const long long cnt = sz[u];
    const long long weight = sum[u] / cnt;  // exact: sum[u] == weight * cnt
    // Form the triangular number first so no intermediate exceeds the final result.
    // weight * cnt * (cnt + 1) on its own can overflow long long before the division.
    ans[u] = weight * (cnt * (cnt + 1) / 2);
    return u;
}

int merge(int u, int v, int l, int r) {
    if (!u || !v) return u | v;
    if (l == r) return mergeleaf(u, v);
    int mid = (l + r) >> 1;
    ls[u] = merge(ls[u], ls[v], l, mid);
    rs[u] = merge(rs[u], rs[v], mid + 1, r);
    sz[u] = sz[ls[u]] + sz[rs[u]];
    sum[u] = sum[ls[u]] + sum[rs[u]];
    ans[u] = ans[ls[u]] + ans[rs[u]] + sum[ls[u]] * (long long) sz[rs[u]];
    return u;
}

// Builds the root-to-leaf path for discretized value x in the tree rooted at rt over
// [l, r], allocating nodes as needed. Every node on the path is set to
// sz = 1, sum = ans = b[x]. That is only correct when rt starts empty (a one-element
// tree); it does not accumulate into an existing tree.
void update(int x, int &rt, int l, int r) {
    int *slot = &rt;
    for (;;) {
        if (*slot == 0) *slot = ++tot;
        const int node = *slot;
        sz[node] = 1;
        sum[node] = ans[node] = (long long) b[x];
        if (l == r) break;
        const int mid = (l + r) >> 1;
        if (x <= mid) {
            slot = &ls[node];
            r = mid;
        } else {
            slot = &rs[node];
            l = mid + 1;
        }
    }
}

void dfs(int u, int fa) {
    update(a[u], root[u], 1, m);
    for (int i = head[u]; i != -1; i = e[i].next) {
        int v = e[i].to;
        if (v == fa) continue;
        dfs(v, u);
        root[u] = merge(root[u], root[v], 1, m);
    }
}

// Reads T test cases. For each case it reads n weights and n-1 tree edges, then prints
// every vertex's answer on one line (each value followed by a space).
int main() {
    int T, n;
    scanf("%d", &T);
    while (T--) {
        edge_init();
        scanf("%d", &n);
        // Read the weights and keep a copy for coordinate compression.
        for (int i = 1; i <= n; ++i) {
            scanf("%d", a + i);
            b[i] = a[i];
        }
        // Compress: b[1..m] = sorted distinct weights, a[i] = 1-based rank of a[i] in b.
        std::sort(b + 1, b + n + 1);
        m = int(std::unique(b + 1, b + n + 1) - b) - 1;
        for (int i = 1; i <= n; ++i) {
            a[i] = std::lower_bound(b + 1, b + m + 1, a[i]) - b;
        }
        // Undirected tree edges.
        int u, v;
        for (int i = 1; i < n; ++i) {
            scanf("%d%d", &u, &v);
            add(u, v);
            add(v, u);
        }
        tot = 0;
        memset(root, 0, sizeof(root));
        dfs(1, 0);
        for (int i = 1; i <= n; ++i) printf("%lld ", ans[root[i]]);
        puts("");
        // Reset only the segment-tree nodes allocated in this test case.
        std::fill(ls + 1, ls + tot + 1, 0);
        std::fill(rs + 1, rs + tot + 1, 0);
        std::fill(sz + 1, sz + tot + 1, 0);
        std::fill(sum + 1, sum + tot + 1, 0LL);
        std::fill(ans + 1, ans + tot + 1, 0LL);
    }
}

 

 

 

posted @ 2017-10-14 11:49  zd11024  阅读(132)  评论(0编辑  收藏  举报