【树结构】树链剖分简单分析

【树结构】树链剖分

当我们需要在一棵树上完成某些区间操作,而且要求复杂度严格保持在 lg 级别,那么树链剖分往往是不错的选择。

所谓树链剖分,就是把树分割成链,把每条链放到线段树或其他数据结构里面维护。显然,只要我们保证每次区间操作涉及的链的个数为 O(lgn) ,就可以保证总查询或修改复杂度为O(lg2n)。一种常用的分割方式是 “轻重剖分” ,相关资料网上可以查询到。

对于一个区间查询如“对a, b最短路径上的所有节点权值求和”,只需要用倍增处理出c = LCA(a, b)转化为对于一个节点和他祖先节点的区间求和。之后只需要不断检查c是否在当前区间内。如果在直接调用数据结构求和,如果不在则求这一区间的总和,并将节点向上推,直到depth[a] < depth[c]。

以zjoi的树的统计一题为例给出代码:

#include <bits/stdc++.h>
using namespace std;
const int maxm = 100005, maxn = 100005;

int zkw[maxn * 4], sum[maxn * 4], N = 131072, t = 0;
void update(int i, int k)
{
    i += N - 1; zkw[i] = sum[i] = k;
    for (i >>= 1; i; i >>= 1) {
        zkw[i] = max(zkw[i << 1], zkw[(i << 1) + 1]);
        sum[i] = sum[i << 1] + sum[(i << 1) + 1];
    }
}

pair<int, int> query(int i, int j)
{
    int ans = 0, maxi = INT_MIN, p, q;
    for (p = i + N - 1, q = j + N - 1; p < q; p >>= 1, q >>= 1) {
        if (p & 1) { ans += sum[p], maxi = max(maxi, zkw[p]); p++; }
        if (!(q & 1)) { ans += sum[q], maxi = max(maxi, zkw[q]); q--; }
    }
    if (p == q) ans += sum[p], maxi = max(maxi, zkw[p]);
    return make_pair(ans, maxi);
}

struct node {
    int to, next;
    node() { to = next = 0; }
} edge[2 * maxm];
int head[maxn], top = 0;
int dat[maxn], siz[maxn], id[maxn], ind[maxn], hev[maxn], dep[maxn];
int father[maxn][35];
int n, m;

void push(int i, int j) { edge[++top].to = j; edge[top].next = head[i]; head[i] = top; }
int dfs1(int i)
{
    siz[i] = 1;
    for (int k = head[i]; k; k = edge[k].next) {
        if (!dep[edge[k].to]) {
            dep[edge[k].to] = dep[i] + 1; father[edge[k].to][0] = i;
            siz[i] += dfs1(edge[k].to);
        }
    }
    return siz[i];
}
void dfs2(int i, int from)
{
    ind[i] = from; update(++t, dat[i]); id[i] = t;
    if (!head[i]) return;
    hev[i] = 0;
    for (int k = head[i]; k; k = edge[k].next) {
        if (dep[edge[k].to] > dep[i] && siz[edge[k].to] > siz[hev[i]])
            hev[i] = edge[k].to;
    }
    if (!hev[i]) {return;}
    dfs2(hev[i], from);
    for (int k = head[i]; k; k = edge[k].next)
        if (dep[edge[k].to] > dep[i] && edge[k].to != hev[i])
            dfs2(edge[k].to, edge[k].to);
}

void travel(int, int);
void init()
{
    dep[1] = 1;
    memset(father, 0, sizeof father);
    dfs1(1);
    dfs2(1, 1);
    for (int j = 1; j <= 20; j++)
        for (int i = 1; i <= n; i++)
            father[i][j] = father[father[i][j-1]][j-1];
}
inline int lowbit(int i) { return i&(-i); }
int lca(int a, int b)
{
    if (dep[a] < dep[b]) swap(a, b);
    int dd = dep[a] - dep[b];
    while (dd) { a = father[a][(int)(log2(lowbit(dd)))]; dd -= lowbit(dd); }
    if (a == b) return a;
    for (int i = 20; i >= 0; i--)
        if (father[a][i] != father[b][i])
            a = father[a][i], b = father[b][i];
    return father[a][0];
}

int query_sum(int i, int j) // j is anc of i
{
    if (dep[i] < dep[j]) return 0;
    if (dep[ind[i]] <= dep[j])
        return query(id[j], id[i]).first;
    return query_sum(father[ind[i]][0], j) + query(id[ind[i]], id[i]).first;
}

int query_max(int i, int j)
{
    if (dep[i] < dep[j]) return INT_MIN;
    if (dep[ind[i]] <= dep[j])
        return query(id[j], id[i]).second;
    return max(query_max(father[ind[i]][0], j), query(id[ind[i]], id[i]).second);
}

inline void change(int i, int j) { update(id[i], j); }
inline int read() { int a; scanf("%d", &a); return a; }

int main()
{
    memset(dep, 0, sizeof dep);
    memset(head, 0, sizeof head);
    memset(hev, 0, sizeof hev);
    memset(sum, 0, sizeof sum);
    memset(zkw, -127/3, sizeof zkw);
    n = read();
    for (int i = 1; i < n; i++) {
        int a, b; a = read(); b = read();
        push(a, b);
        push(b, a);
    }
    for (int i = 1; i <= n; i++)
        dat[i] = read();
    init();
    m = read();
    char str[10]; int a, b, c;
    for (int i = 1; i <= m; i++) {
        scanf("%s", str);
        a = read(); b = read();
        if (strcmp(str, "CHANGE") == 0) change(a, b);
        else if (strcmp(str, "QSUM") == 0) {
            c = lca(a, b);
            printf("%d\n", query_sum(a, c)+query_sum(b, c)-query_sum(c, c));
        }
        else {
            c = lca(a, b);
            printf("%d\n", max(query_max(a, c), query_max(b, c)));
        }
    }
    return 0;
}
posted @ 2016-12-30 19:33  ljt12138  阅读(190)  评论(0编辑  收藏  举报