[学习笔记] 线段树合并

知识讲解

将两颗线段树(一般是权值线段树)合并,一般要用到动态开点。

合并两颗线段树

这是离线的做法,会破坏\(y\)的结构,对于这种问题,我们一般是离线的。

1.对于两棵线段树都有的节点,新的线段树的该节点值为两者和。
2.对于某一棵线段树有的节点,新的线段树保存该节点的值。
3.然后对左右子树递归处理。

时间复杂度均摊\(O(\log n)\)

inline void merge (int &x, int &y, int l, int r) {
    if (!x || !y) { x = x + y; return;}
    if (l == r) { do something; return;}
    int mid = (l + r) >> 1;
    merge (ls(x), ls(y), l, mid);
    merge (rs(x), rs(y), mid + 1, r);
    push_up (x);
}

在线做法

新开节点,这样不会破坏\(y\)树的结构,但是空间会很大。

inline int merge(int x, int y, int l, int r) {
    if(x == 0 || y == 0) {
        x = x + y;
        return x;
    }
    int p = ++cnt;
    if(l == r) { 
      dosome thing
      return p;
    }
    int mid = (l + r) >> 1;
    ls(p) = merge(ls(x), ls(y), l, mid);
    rs(p) = merge(rs(x), rs(y), mid + 1, r);
    pushup(p);
    return p;
}

易错

& 符号不要忘记加
\(p==0\)的时候要考虑清楚。

例题讲解

熟悉的\(\text{CF600E}\)
我们可以为每一个子节点都开一个权值线段树,显然很多空间浪费,所以要动态开点。
然后你想要获得\(u\)的信息,就得知道\(v\)的子树信息,可以将子树的信息合并。
时间复杂度\(O(n \log n)\) 空间复杂度\(O(n \log n)\)

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <set>
#include <map>
#include <queue>

using namespace std;

template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}

typedef long long ll;

const int INF = 2139062143;

#define DEBUG(x) std::cerr << #x << " = " << x << std::endl

template <typename T> void read (T &x) {
    x = 0; bool f = 1; char ch;
    do {ch = getchar(); if (ch == '-') f = 0;} while (ch > '9' || ch < '0');
    do {x = x * 10 + ch - '0'; ch = getchar();} while (ch >= '0' && ch <= '9');
    x = f ? x : -x;
}

template <typename T> void write (T x) {
    if (x < 0) x = ~x + 1, putchar ('-');
    if (x > 9) write (x / 10);
    putchar (x % 10 + '0');
}

const int N = 1e5 + 7;

struct EDGE {
    int to, next;
} edge[N << 1];

struct Segment_Tree {
    int ls, rs;
    ll mx, sum;
    #define ls(p) (t[p].ls)
    #define rs(p) (t[p].rs)
} t[N * 20];

int n, E, cnt, c[N], head[N], root[N];
ll ans[N];

inline void addedge (int u, int v) {
    edge[++E].to = v;
    edge[E].next = head[u];
    head[u] = E;
}

inline void push_up (int x) {
    t[x].mx = max (t[ls(x)].mx, t[rs(x)].mx);
    t[x].sum = 0;
    if (t[x].mx == t[ls(x)].mx) t[x].sum += t[ls(x)].sum;
    if (t[x].mx == t[rs(x)].mx) t[x].sum += t[rs(x)].sum;
}

inline void merge (int &x, int &y, int l, int r) {
    if (!x || !y) {
        x = x + y;
        return;
    }
    if (l == r) {
        t[x].mx += t[y].mx;
        t[x].sum = l;
        return;
    }
    int mid = (l + r) >> 1;
    merge (ls(x), ls(y), l, mid);
    merge (rs(x), rs(y), mid + 1, r);
    push_up (x);
}

inline void insert (int &x, int l, int r, int pos) {
    if (!x) x = ++ cnt;
    if (l == r) {
        t[x].mx = 1;
        t[x].sum = l;
        return;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid) insert (ls(x), l, mid, pos);
    else insert (rs(x), mid + 1, r, pos);
    push_up (x);
}

inline void dfs (int u, int f) {
    insert (root[u], 1, n, c[u]);
    for (int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if (v == f) continue;
        dfs (v, u);
        merge (root[u], root[v], 1, n);
    }
    ans[u] = t[root[u]].sum;
}

int main () {
    read (n);
    for (int i = 1; i <= n; i ++ ) read (c[i]);
    for (int i = 1, u, v; i < n; i ++ ) {
        read (u); read (v);
        addedge (u, v);
        addedge (v, u);
    }
    dfs (1, 0);
    for (int i = 1; i <= n; i ++ ) printf ("%lld ", ans[i]);
    return 0;
}
posted @ 2020-02-06 23:15  Hock  阅读(136)  评论(0编辑  收藏  举报