线段树合并
线段树合并
前置芝士
动态开点线段树和权值线段树
乍一看,线段树合并和上面那两个奇怪的东西有什么关系。
其实,线段树合并的全称为动态开点权值线段树合并( 雾
如果对上面那两个奇怪的东西不理解可点开链接进行搜索(大雾
优点
动态开点线段树有着一些优点,比如说当你让某个节点继承另一个节点的左儿子或者右儿子的时候,你可以不用新建一棵线段树,而是直接将该节点的左右儿子赋成那个节点的左右儿子就行了,总之就是空间上有一定的优越性。
权值线段树能代替平衡树做一些求 kk 大、排名、找前驱后继的操作。(显然是我不会平衡树,如果你会平衡树当我没说)
概念
线段树合并,顾名思义,就是建立一棵新的线段树保存原有的两颗线段树的信息。
合并方式主要如下:
如果不能理解,可以往下翻看代码
所以问题来了,复杂度是多少?
复杂度
(转自洛谷日报)
代码
合并(好像也就这一个操作,别的和动态开点,权值线段树一样)
//tot在这里面就是记录编号的 ls和rs为lson和rson
int merge(int a,int b,int l,int r){
if(!a) return b;
if(!b) return a;
//可写成 if(!a||!b) return a|b;
if(l==r)
{
sum[++tot]=sum[a]+sum[b];
return tot;
}
int mid=(l+r)>>1;
//这里省略若干操作,因题而异
ls[++tot]=merge(ls[a],ls[b],l,mid);
rs[tot]=merge(rs[a],rs[b],mid+1,r);
sum[tot]=sum[ls[tot]]+sum[rs[tot]];
return tot;
}
例题1
CF600E
思路
线段树合并。权值线段树覆盖颜色1−>100000,用sum1−>100000,用sum表示颜色最多出现的次数,ans表示答案。分3种情况push_up即可。
- 左右子树sum相等
- 左边>右边
- 左边<右边
dfs的时merge一下即可。
代码
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#define mid (l+r>>1)
#define lson tr[i].l
#define rson tr[i].r
#define int long long
using namespace std;
const int maxn=100010;
int col[maxn];
int n,cnt;
int rt[maxn];
vector<int>g[maxn];
int anss[maxn];
struct node{ //这里使用的存储方式不是ls[],rs[]而是结构体
int l,r,sum,ans;//sum为最多出现颜色次数 ans为最多出现编号
}tr[maxn*40];
inline void push_up(int i)
{
if(tr[lson].sum==tr[rson].sum)
{
tr[i].sum=tr[lson].sum;
tr[i].ans=tr[lson].ans+tr[rson].ans;
}
else if(tr[lson].sum<tr[rson].sum)
{
tr[i].sum=tr[rson].sum;
tr[i].ans=tr[rson].ans;
}
else
{
tr[i].sum=tr[lson].sum;
tr[i].ans=tr[lson].ans;
}
}
inline void update(int &i,int l,int r,int pos)
{
if(!i)i=++cnt;
if(l==r)
{
tr[i].sum++;tr[i].ans=l;
return;
}
if(pos<=mid)update(lson,l,mid,pos);
else update(rson,mid+1,r,pos);
push_up(i);
}
inline int merge(int a,int b,int l,int r)
{
if(!a||!b)return a+b;
if(l==r)
{
tr[a].sum+=tr[b].sum;tr[a].ans=l;
return a;
}
tr[a].l=merge(tr[a].l,tr[b].l,l,mid);
tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r);
push_up(a);
return a;
}
inline void dfs(int now,int fa)
{
for(int i=0;i<g[now].size();i++)
{
if(g[now][i]==fa)continue;
dfs(g[now][i],now);
merge(rt[now],rt[g[now][i]],1,100000);
}
update(rt[now],1,100000,col[now]);
anss[now]=tr[rt[now]].ans;
}
signed main()
{
ios::sync_with_stdio(false); //cf上用%lld输入好像不大行,所以改成cin的快读了
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>col[i];
rt[i]=i;
cnt++;
}
for(int i=1;i<n;i++)
{
int from,to;
cin>>from>>to;
g[from].push_back(to);
g[to].push_back(from);
}
dfs(1,0);
for(int i=1;i<=n;i++)
{
cout<<anss[i]<<" ";
}
return 0;
}
(这道题除了线段树合并,还可以用树上启发式合并/dsu on tree 来解决,有兴趣的读者可自行搜索)
例题2
P3521 [POI2011]ROT-Tree Rotations
思路
这道题主要就是权值线段树合并的一个过程。我们对每个叶子结点开一个权值线段树,然后逐步合并。
考虑到一件事情:如果在原树有一个根节点 \(x\),和其左儿子 \(ls\) ,右儿子 \(rs\) 。我们要合并的是 \(ls\) 的权值线段树和 \(rs\) 的权值线段树,得到 \(x\) 的所有叶节点的权值线段树。
发现交换 \(ls\) 和 \(rs\) 并不会对原树更上层之间的逆序对产生影响,于是我们只需要每次合并都让逆序对最少。
于是我们的问题转化为了给定两个权值线段树,问把它们哪个放在左边可以使逆序对个数最小,为多少。
考虑我们合并到一个节点,其权值范围为 \([l,r]\) ,中点为 \(mid\) 。这个时候我们有两棵树,我们要分别计算出某棵树在左边的时候和某棵树在右边的时候的逆序对个数。事实上我们只需要处理权值跨过中点 \(mid\) 的逆序对,那么所有的逆序对都会在递归过程中被处理仅一次(类似一个分治的过程)。而我们这个时候可以轻易的算出两种情况的逆序对个数,不交换的话是左边那棵树的右半边乘上右边那棵树的的左半边的大小;交换的话则是左边那棵树的左半边乘上左边那棵树的的右半边的大小。
然后每次合并由于都可以交换左右子树,我们就把这次合并中交换和不交换的情况计算一下,取最小值累积就可以了。
空间复杂度:\(O(n \log n)\),时间复杂度 \(O(n \log n)\)。
#include<cstdio>
#define rep(i, a, b) for (register int i=(a); i<=(b); ++i)
#define per(i, a, b) for (register int i=(a); i>=(b); --i)
using namespace std;
const int N=6000005;
long long min(long long a, long long b){return a<b?a:b;}
int ls[N], rs[N], val[N], n, tot;
long long ans, ans1, ans2;
inline int read()
{
int x=0,f=1;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
for (;ch>='0'&&ch<='9';ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
return x*f;
}
int New(int l, int r, int x)
{
val[++tot]=1;
if (l==r) return tot;
int mid=l+r>>1, node=tot;
if (x<=mid) ls[node]=New(l, mid, x);
else rs[node]=New(mid+1, r, x);
return node;
}
int merge(int l, int r, int u, int v)
{
if (!u || !v) return u+v;
if (l==r) {val[u]=val[u]+val[v]; return u;}
int mid=(l+r)>>1, node=u;
ans1+=1ll*val[rs[u]]*val[ls[v]];
ans2+=1ll*val[ls[u]]*val[rs[v]];
ls[node]=merge(l, mid, ls[u], ls[v]);
rs[node]=merge(mid+1, r, rs[u], rs[v]);
val[node]=val[ls[node]]+val[rs[node]];
return node;
}
int dfs()
{
int v=read();
if (v) return New(1, n, v);
int node=merge(1, n, dfs(), dfs());
ans+=min(ans1, ans2); ans1=ans2=0;
return node;
}
int main()
{
n=read(); dfs(); printf("%lld\n", ans);
return 0;
}