线段树合并
线段树合并,字面意思,可以将两个线段树合并到一起。如果我们的dp需要将两个数组相加或相乘,亦或是一些其他的操作,那么我们可以将这两个数组上建的线段树合并到一起去,可以加速我们dp的速度。
线段树合并的复杂度是 \(nlogn\) 的感觉挺玄学,但确实能证明。大概是因为我们对线段树进行操作后一共只会产生 \(nlogn\) 个点,而每进行一次操作都会将点数减一,所以复杂度可证。
思考一道例题,P4556
考虑这些链上的加操作,我们可以使用树上差分,只修改u,v和lca,fa[lca]来进行修改操作。然后我们的每一个点将它的儿子的值加到它自己身上。
现在得到了优秀的 \(n^2\) 算法,考虑使用线段树合并来进行儿子往父亲的合并。可以将复杂度优化到 \(nlogn\) 。
那么如何线段树合并呢?
int merge(int a,int b)
{
if(!a || !b) return a|b;
ls[a] = merge(ls[a],ls[b]);
rs[a] = merge(rs[a],rs[b]);
push_up;
}
大概如此,更多操作需要根据题意进行。
下面是此题的完整代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#define mid (l+r>>1)
using namespace std;
int read()
{
int a = 0,x = 1;char ch = getchar();
while(ch > '9' || ch < '0') {if(ch == '-') x = -1;ch = getchar();}
while(ch >= '0' && ch <= '9') {a = a*10 + ch-'0';ch = getchar();}
return a*x;
}
const int N=5e6+7,R=1e5;
int n,m;
int head[N],go[N],nxt[N],cnt,ans[N];
void add(int u,int v)
{
go[++cnt] = v;
nxt[cnt] = head[u];
head[u] = cnt;
}
int fa[N][31],dep[N],rt[N],tot,ls[N],rs[N],tre[N],siz[N];
void dfs1(int u)
{
dep[u] = dep[fa[u][0]] + 1;
for(int i = 1;i <= 30;i ++) fa[u][i] = fa[fa[u][i-1]][i-1];
for(int e = head[u];e;e = nxt[e]) {
int v = go[e];
if(v == fa[u][0]) continue;
fa[v][0] = u;dfs1(v);
}
}
int LCA(int a,int b)
{
if(dep[a] < dep[b]) swap(a,b);
for(int i = 30;i >= 0;i --) if(dep[fa[a][i]] >= dep[b]) a = fa[a][i];
if(a == b) return a;
for(int i = 30;i >= 0;i --) if(fa[a][i] != fa[b][i]) a = fa[a][i],b = fa[b][i];
return fa[a][0];
}
void pushup(int root) {
if(siz[ls[root]] >= siz[rs[root]]) tre[root] = tre[ls[root]];
else tre[root] = tre[rs[root]];
siz[root] = max(siz[ls[root]],siz[rs[root]]);
}
void modify(int &root,int l,int r,int p,int x)
{
if(!root) root = ++tot;
if(l == r && l == p) {siz[root] += x,tre[root] = l;return ;}
if(p <= mid) modify(ls[root],l,mid,p,x);
else modify(rs[root],mid+1,r,p,x);
pushup(root);
}
int merge(int a,int b,int l,int r)
{
if(!a || !b) return a|b;
if(l == r) {siz[a] += siz[b];return a;}
ls[a] = merge(ls[a],ls[b],l,mid);
rs[a] = merge(rs[a],rs[b],mid+1,r);
pushup(a);return a;
}
void dfs(int u)
{
for(int e = head[u];e;e = nxt[e]) {
int v = go[e];if(v == fa[u][0]) continue;
dfs(v);rt[u] = merge(rt[u],rt[v],1,R);
}
ans[u] = siz[rt[u]]?tre[rt[u]]:0;
}
int main()
{
// freopen("in.in","r",stdin);
// freopen("out.out","w",stdout);
n = read(),m = read();
for(int i = 1;i < n;i ++) {
int u = read(),v = read();
add(u,v);add(v,u);
}
dfs1(1);
for(int i = 1;i <= m;i ++) {
int x = read(),y = read(),z = read(),tmp = LCA(x,y);
modify(rt[x],1,R,z,1);modify(rt[y],1,R,z,1);
modify(rt[tmp],1,R,z,-1);if(fa[tmp][0]) modify(rt[fa[tmp][0]],1,R,z,-1);
}
dfs(1);
for(int i = 1;i <= n;i ++) printf("%d\n",ans[i]);
return 0;
}