算法学习————线段树合并
一、线段树合并的思想
线段树合并,顾名思义,就是建立一棵新的线段树保存原有的两颗线段树的信息。
二、线段树合并的流程
假设我们合并到了两棵树的pos位置
-
如果a有pos位置,b没有,那么新的线段树pos位置赋成a,返回
-
如果b有pos位置,a没有,赋成b,返回
-
如果此时已经合并到两棵线段树的叶子节点了,就把b在pos的值加到a上,把新线段树上的pos位置赋成a,返回
-
递归处理左子树
-
递归处理右子树
-
用左右子树的值更新当前节点
-
将新线段树上的pos位置赋成a,返回
代码:
int merge(int a,int b,int l,int r){
if (!a) return b;
if (!b) return a;
int res = a;
if (l == r){
t[res].sum = t[a].sum+t[b].sum;
t[res].ans = l;
return res;
}
int mid = (l+r >> 1);
t[res].l = merge(t[a].l,t[b].l,l,mid);
t[res].r = merge(t[a].r,t[b].r,mid+1,r);
// cout<<l<<" "<<r<<endl;
// cout<<"merge = "<<a<<" "<<t[a].sum<<" "<<t[a].ans<<" "<<b<<" "<<t[b].sum<<" "<<t[b].ans<<endl;
// cout<<"res = "<<res<<" "<<t[res].l<<" "<<t[res].r<<endl;
pushup(res);
return res;
}
例题:CF600E Lomsat gelral
线段树怎么维护呢??
建一棵权值线段树,线段树上的每个节点维护两个值,当前区间颜色出现的最大次数,和出现次数为最大次数的颜色的和
这样每次更新的时候,如果左儿子的颜色出现的最大次数大,直接等于左儿子,但注意不要修改左右两个儿子的指针
如果右儿子的颜色出现的最大次数大,直接等于右儿子,如果相等,则把颜色的和更新
每次在dfs回溯的时候把儿子和自己的合并,最后再把自己加进去
代码:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#define ll long long
#define B cout<<"Breakpoint"<<endl;
#define O(x) cout<<#x<<" "<<x<<endl;
#define o(x) cout<<#x<<" "<<x<<" ";
using namespace std;
int read(){
int x = 1,a = 0;char ch = getchar();
while (ch < '0'||ch > '9'){if (ch == '-') x = -1;ch = getchar();}
while (ch >= '0'&&ch <= '9'){a = a*10+ch-'0';ch = getchar();}
return x*a;
}
const int maxn = 2e5+10;
int n,col[maxn];
struct node{
int to,nxt;
}ed[maxn << 1];
int head[maxn],tot;
void add(int u,int to){
ed[++tot].to = to;
ed[tot].nxt = head[u];
head[u] = tot;
}
struct SEGTree{
int l,r,sum;
ll ans;
}t[maxn << 2];
void pushup(int x){
int l = t[x].l,r = t[x].r;
// cout<<"pushup = = "<<l<<" "<<r<<" "<<t[l].sum<<" "<<t[x].ans<<" "<<t[r].sum<<" "<<t[r].ans<<endl;
if (t[l].sum > t[r].sum){
t[x].ans = t[l].ans;
t[x].sum = t[l].sum;
}
if (t[r].sum > t[l].sum){
t[x].ans = t[r].ans;
t[x].sum = t[r].sum;
}
if (t[l].sum == t[r].sum){
t[x].sum = t[l].sum;
t[x].ans = t[l].ans+t[r].ans;
}
// cout<<"pushup = "<<t[x].sum<<" "<<t[x].ans<<endl;
}
int cnt;
void modify(int &x,int lst,int l,int r,int p,int k){
if (!x) x = ++cnt;
if (l == r){
t[x].sum = t[lst].sum+1;
t[x].ans = l;
// cout<<"modify = "<<x<<" "<<l<<" "<<t[x].sum<<" "<<t[x].ans<<endl;
return;
}
int mid = (l+r >> 1);
if (p <= mid) t[x].r = t[lst].r,modify(t[x].l,t[lst].l,l,mid,p,k);
else t[x].l = t[lst].l,modify(t[x].r,t[lst].r,mid+1,r,p,k);
// cout<<x<<" "<<t[x].l<<" "<<t[x].r<<endl;
pushup(x);
}
int merge(int a,int b,int l,int r){
if (!a) return b;
if (!b) return a;
int res = a;
if (l == r){
t[res].sum = t[a].sum+t[b].sum;
t[res].ans = l;
return res;
}
int mid = (l+r >> 1);
t[res].l = merge(t[a].l,t[b].l,l,mid);
t[res].r = merge(t[a].r,t[b].r,mid+1,r);
// cout<<l<<" "<<r<<endl;
// cout<<"merge = "<<a<<" "<<t[a].sum<<" "<<t[a].ans<<" "<<b<<" "<<t[b].sum<<" "<<t[b].ans<<endl;
// cout<<"res = "<<res<<" "<<t[res].l<<" "<<t[res].r<<endl;
pushup(res);
return res;
}
int root[maxn];
ll ans[maxn];
void dfs(int x,int fa){
for (int i = head[x];i;i = ed[i].nxt){
int to = ed[i].to;
if (to == fa) continue;
dfs(to,x);
root[x] = merge(root[x],root[to],1,n);
}
modify(root[x],root[x],1,n,col[x],1);
ans[x] = t[root[x]].ans;
}
int main(){
n = read();
for (int i = 1;i <= n;i++) col[i] = read();
for (int i = 1;i < n;i++){
int x = read(),y = read();
add(x,y),add(y,x);
}
dfs(1,0);
for (int i = 1;i <= n;i++) printf("%lld ",ans[i]);
return 0;
}