P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并
P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并
在这题里面讲一下线段树合并。顾名思义就是把多个线段树合并成一个。
显然完全二叉线段树(也就是普通线段树)是无法更高效的合并的,只能把所有节点加起来建个新树。但是在动态开点线段树中,有时候一个树只有几条链,这时候我们就是可以使用线段树合并了。
其核心就是只遍历两棵树重叠的部分实现合并。具体的,对于两棵线段树 \(p1\),\(p2\),我们要把 \(p2\) 合并到 \(p1\) 上。不妨先考虑左儿子,假如两个线段树都有左儿子,那么就往下遍历;假如只有 \(p1\) 有左儿子,显然不需要修改;假如只有 \(p2\) 有左儿子,\(p1\) 没有记录信息,那么只需要把 \(p1\) 左儿子的指针改成 \(p2\) 左儿子的指针即可。复杂度不会严格证明,一般所有线段树的总节点个数为 \(O(n\log n)\),合并时遍历到的只有重叠部分,非重叠部分没有遍历,所以每个线段树的节点至多被遍历一次,所以时间复杂度是 \(O(n\log n)\)。
void mg(int p1, int p2, int l, int r) {
if(l == r) {
.....
return;
}
int mid = (l + r) >> 1;
if(t[p1].ls && t[p2].ls) mg(t[p1].ls, t[p2].ls, l, mid);
else if(t[p2].ls) t[p1].ls = t[p2].ls;
if(t[p1].rs && t[p2].rs) mg(t[p1].rs, t[p2].rs, mid + 1, r);
else if(t[p2].rs) t[p1].rs = t[p2].rs;
pushup(p1, t[p1].ls, t[p1].rs);
}
回到这题,容易想到树上差分,然后把操作拆成 \(4\) 个单点修改,离线在树上线段树合并即可。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back
typedef long long i64;
const int N = 1e5 + 10;
int n, m, cnt;
int dep[100010], anc[100010][21], h[100010];
struct node{
int to, nxt;
} e[200010];
void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = h[u];
h[u] = cnt;
}
void dfs(int u, int fa) {
for(int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
anc[v][0] = u;
dep[v] = dep[u] + 1;
dfs(v, u);
}
}
void init() {
for(int j = 1; j <= 19; j++) {
for(int i = 1; i <= n; i++) {
anc[i][j] = anc[anc[i][j - 1]][j - 1];
}
}
}
int lca(int u, int v) {
if(dep[u] < dep[v]) std::swap(u, v);
for(int i = 19; i >= 0; i--) if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
if(u == v) return u;
for(int i = 19; i >= 0; i--) if(anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
return anc[u][0];
}
int tot;
std::vector<pii> ve[N];
struct seg {
int ls, rs, mx, cnt;
} t[40 * N];
void pushup(int u, int ls, int rs) {
if(t[ls].cnt >= t[rs].cnt && t[ls].cnt) {
t[u].cnt = t[ls].cnt;
t[u].mx = t[ls].mx;
} else if(t[ls].cnt <= t[rs].cnt && t[rs].cnt) {
t[u].cnt = t[rs].cnt;
t[u].mx = t[rs].mx;
} else {
t[u].cnt = t[u].mx = 0;
}
return;
}
void ins(int &u, int l, int r, int x, int y) {
if(!u) u = ++tot;
if(l == r) {
t[u].cnt += y;
t[u].mx = t[u].cnt ? l : 0;
return;
}
int mid = (l + r) >> 1;
if(x <= mid) ins(t[u].ls, l, mid, x, y);
else ins(t[u].rs, mid + 1, r, x, y);
pushup(u, t[u].ls, t[u].rs);
}
void mg(int p1, int p2, int l, int r) {
if(l == r) {
t[p1].cnt += t[p2].cnt;
t[p1].mx = t[p1].cnt ? l : 0;
return;
}
int mid = (l + r) >> 1;
if(t[p1].ls && t[p2].ls) mg(t[p1].ls, t[p2].ls, l, mid);
else if(t[p2].ls) t[p1].ls = t[p2].ls;
if(t[p1].rs && t[p2].rs) mg(t[p1].rs, t[p2].rs, mid + 1, r);
else if(t[p2].rs) t[p1].rs = t[p2].rs;
pushup(p1, t[p1].ls, t[p1].rs);
}
int ans[N];
void dfs2(int u, int fa) {
for(int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs2(v, u);
mg(u, v, 1, 1e5);
}
for(auto x : ve[u]) ins(u, 1, 1e5, x.fi, x.se);
ans[u] = t[u].mx;
}
void Solve() {
std::cin >> n >> m;
tot = n;
for(int i = 1; i < n; i++) {
int u, v;
std::cin >> u >> v;
add(u, v), add(v, u);
}
dep[1] = 1;
dfs(1, 0);
init();
while(m--) {
int u, v, w;
std::cin >> u >> v >> w;
int rt = lca(u, v);
ve[u].push_back({w, 1}), ve[v].push_back({w, 1});
ve[rt].push_back({w, -1}), ve[anc[rt][0]].push_back({w, -1});
}
dfs2(1, 0);
for(int i = 1; i <= n; i++) std::cout << ans[i] << "\n";
}
int main() {
Solve();
return 0;
}