[学习笔记] 线段树合并
知识讲解
将两颗线段树(一般是权值线段树)合并,一般要用到动态开点。
合并两颗线段树
这是离线的做法,会破坏\(y\)的结构,对于这种问题,我们一般是离线的。
1.对于两棵线段树都有的节点,新的线段树的该节点值为两者和。
2.对于某一棵线段树有的节点,新的线段树保存该节点的值。
3.然后对左右子树递归处理。
时间复杂度均摊\(O(\log n)\)
inline void merge (int &x, int &y, int l, int r) {
if (!x || !y) { x = x + y; return;}
if (l == r) { do something; return;}
int mid = (l + r) >> 1;
merge (ls(x), ls(y), l, mid);
merge (rs(x), rs(y), mid + 1, r);
push_up (x);
}
在线做法
新开节点,这样不会破坏\(y\)树的结构,但是空间会很大。
inline int merge(int x, int y, int l, int r) {
if(x == 0 || y == 0) {
x = x + y;
return x;
}
int p = ++cnt;
if(l == r) {
dosome thing
return p;
}
int mid = (l + r) >> 1;
ls(p) = merge(ls(x), ls(y), l, mid);
rs(p) = merge(rs(x), rs(y), mid + 1, r);
pushup(p);
return p;
}
易错
& 符号不要忘记加
\(p==0\)的时候要考虑清楚。
例题讲解
熟悉的\(\text{CF600E}\)
我们可以为每一个子节点都开一个权值线段树,显然很多空间浪费,所以要动态开点。
然后你想要获得\(u\)的信息,就得知道\(v\)的子树信息,可以将子树的信息合并。
时间复杂度\(O(n \log n)\) 空间复杂度\(O(n \log n)\)
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <set>
#include <map>
#include <queue>
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}
typedef long long ll;
const int INF = 2139062143;
#define DEBUG(x) std::cerr << #x << " = " << x << std::endl
template <typename T> void read (T &x) {
x = 0; bool f = 1; char ch;
do {ch = getchar(); if (ch == '-') f = 0;} while (ch > '9' || ch < '0');
do {x = x * 10 + ch - '0'; ch = getchar();} while (ch >= '0' && ch <= '9');
x = f ? x : -x;
}
template <typename T> void write (T x) {
if (x < 0) x = ~x + 1, putchar ('-');
if (x > 9) write (x / 10);
putchar (x % 10 + '0');
}
const int N = 1e5 + 7;
struct EDGE {
int to, next;
} edge[N << 1];
struct Segment_Tree {
int ls, rs;
ll mx, sum;
#define ls(p) (t[p].ls)
#define rs(p) (t[p].rs)
} t[N * 20];
int n, E, cnt, c[N], head[N], root[N];
ll ans[N];
inline void addedge (int u, int v) {
edge[++E].to = v;
edge[E].next = head[u];
head[u] = E;
}
inline void push_up (int x) {
t[x].mx = max (t[ls(x)].mx, t[rs(x)].mx);
t[x].sum = 0;
if (t[x].mx == t[ls(x)].mx) t[x].sum += t[ls(x)].sum;
if (t[x].mx == t[rs(x)].mx) t[x].sum += t[rs(x)].sum;
}
inline void merge (int &x, int &y, int l, int r) {
if (!x || !y) {
x = x + y;
return;
}
if (l == r) {
t[x].mx += t[y].mx;
t[x].sum = l;
return;
}
int mid = (l + r) >> 1;
merge (ls(x), ls(y), l, mid);
merge (rs(x), rs(y), mid + 1, r);
push_up (x);
}
inline void insert (int &x, int l, int r, int pos) {
if (!x) x = ++ cnt;
if (l == r) {
t[x].mx = 1;
t[x].sum = l;
return;
}
int mid = (l + r) >> 1;
if (pos <= mid) insert (ls(x), l, mid, pos);
else insert (rs(x), mid + 1, r, pos);
push_up (x);
}
inline void dfs (int u, int f) {
insert (root[u], 1, n, c[u]);
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if (v == f) continue;
dfs (v, u);
merge (root[u], root[v], 1, n);
}
ans[u] = t[root[u]].sum;
}
int main () {
read (n);
for (int i = 1; i <= n; i ++ ) read (c[i]);
for (int i = 1, u, v; i < n; i ++ ) {
read (u); read (v);
addedge (u, v);
addedge (v, u);
}
dfs (1, 0);
for (int i = 1; i <= n; i ++ ) printf ("%lld ", ans[i]);
return 0;
}