离散化权值线段树(求逆序对)
离散化权值线段树
Part 0:作者前言(废话)
以前其实早就学过用二路归并排序的方法求序列的逆序对,因为一直没有学会二路归并,所以逆序对一直不会做
前几天学了线段树,然后无意间在书上看到了“线段树求逆序对”这样的问题……
于是果断魔改一发线段树求一手逆序对。。。然后就有了这个博客
另外,祝贺我考试通过了,暂时不会AFO啦!!!
Part 1:逆序对是什么?
给出如下定义:
对于一个给定的序列\(a\),若序列中任意两个元素组成的二元组\(<a\)\(i\)\(,a\)\(j\)\(>\)满足:\(a\)\(i\)>\(a\)\(j\),且\(i\)<\(j\),则称这个二元组是序列\(a\)的一个逆序对
Part 2:魔改后线段树求逆序对的思路?
显然,我们知道用线段树很容易就可以维护区间和
所以第一步,我们在序列\(a\)的值域上建一个线段树,维护区间和,代表序列在任意一个区间中包含的元素数量
第二步,扫描序列\(a\)中的每一个元素,一个一个加入到权值线段树中,每次加入操作后,求出之前加入的数中,有几个比他大的(也就是查询\(a\)\(i\)\(-a\)\(n\)的区间和)
第三步,刚才第二步已经求出数列中以第\(i\)个数为第二元的逆序对总数了,\(ans\)累加统计答案即可
(没错就是这么简单)
Part 3:两极反转!!!
然而……你以为这么简单就结束了???
\(NONONO\)!!!!
这个做法最最最最大的弊端就是:建树!
我们是在序列a的值域上建立线段树,于是……就有了下面这种情况
(毒瘤)出题人给出序列a的每个元素均在长整型范围内,那么这个做法还没开始,就结束了。。。
那么,问题来了——\(How\) \(to\) \(deal\) \(with\) \(this\) \(f**king\) \(situation?\)
离散化(李散花)大法好啊!!!
我们注意到:我们只是利用到了序列\(a\)中元素的大小关系,所以没有必要存下\(a\)序列中的每个数是多少
举个生动形象的栗子:
给出两个序列:\(a\),\(b\)
\(a[4]=\){\(0x7f7f7f7f,1,2,3\)}
\(b[4]=\){\(4,1,2,3\)}
虽然\(4\)和\(0x7f7f7f7f\)差了好多,但是不影响这两个数列的逆序对数(都是3对)
那么我们的需求变为:把序列\(a\)中任意元素的值映射为大小不超过\(a\)的元素个数的另一个值,并且保持逆序对数不变
实现方法就是排序+去重后把原值映射为他的下标,就可以做到上述要求
排序和去重,当然了,伟大的\(algorithm\)库里早就给我们打造好了这两个函数:
\(sort()\)和\(unique()\),顺带一提\(sort()\)函数时间复杂度稳定为\(O(nlogn)\),他并不是简单的快排,\(sort\)源码很复杂,这里不多做解释(感谢\(zay\)学长告诉我\(QwQ\))
言归正传,我们离散化之后,按照上面的步骤来就可以啦!
Part 4:求逆序对源代码实现(加注释)
#include<algorithm>
#include<cstdio>
using namespace std;
typedef long long int ll;//十年OI一场空,不开long long见祖宗
const int maxn=500005;
int t[maxn],a[maxn],n;
ll ans;
void discretization(){//离散化
scanf("%d",&n);//n是数据总量
for(int i=1;i<=n;i++){
scanf("%d",a+i);//输入元素的同时copy一份,用来排序
t[i]=a[i];
}
sort(t+1,t+n+1);//从小到大快排
int m=unique(t+1,t+n+1)-t-1;//去重,unique返回去重后数组长度(这个说法极其不准确,只是便于理解,如果您想了解更多,百度搜索C++ unique)
for(int i=1;i<=n;i++)
a[i]=lower_bound(t+1,t+m+1,a[i])-t;//寻找a[i]的下标并且用下标覆盖掉原来的a[i]
}
struct sag{//线段树
int l,r;
ll v;
sag *ls,*rs;
inline void push_up() { v=ls->v+rs->v; }//维护区间和
inline bool in_range(const int L,const int R) { return (L<=l)&&(r<=R); }
inline bool outof_range(const int L,const int R) { return (r<L)||(R<l); }
void update(const int L,const int R){
if(in_range(L,R)) v++;//找到这个叶子结点,线段树中的这个元素数量+1
else if(!outof_range(L,R)){
ls->update(L,R);
rs->update(L,R);
push_up();//从下往上更新逆序对数
}
}
ll query(const int L,const int R){
if(in_range(L,R)) return v;
if(outof_range(L,R)) return 0;
return ls->query(L,R)+rs->query(L,R);//返回区间和
}
}*rot;
sag byte[maxn<<1],*pool=byte;//内存池建树
sag* New(const int L,const int R){
sag *u=pool++;
u->l=L,u->r=R;
if(L==R){
u->v=0;
u->ls=u->rs=NULL;
}else{
int Mid=(L+R)>>1;
u->ls=New(L,Mid);
u->rs=New(Mid+1,R);
u->push_up();
}
return u;
}
int main(){
discretization();
rot=New(1,n);//建立1-n的线段树
for(int i=1;i<=n;i++){//枚举每个元素
ans+=rot->query(a[i]+1,n);//因为相等元素不构成逆序对,所以a[i]+1
rot->update(a[i],a[i]);//该元素数量++
}
printf("%lld",ans);
return 0;
}
Part 5:再魔改
前面说了求逆序对既可以用归并排序,也可以用线段树,对吧?
那么现在我们再次对他魔改,让我们的线段树做点归并排序做不了的事:
昨天我遇到了这样一个题:
题意大概是这样的:
在一个序列里,定义三元组\(<a\)\(i\)\(,a\)\(j\)\(,a\)\(k\)\(>\),满足\(a\)\(i\)\(<a\)\(j\)\(<a\)\(k\)且\(i<j<k\),求这个序列中上述三元组的数量
原题网址:https://www.luogu.com.cn/problem/P1637
我们可以考虑枚举中间值\(k\),计算比\(k\)小的元素有\(lis\)个,比\(k\)大的元素有\(mor\)个,根据乘法原理,以k为第二元的三元组就有\(lis*mor\)个
发现这样的做法需要维护区间元素个数和,我们又双叒叕很自然的想起了上面提到过的离散化权值线段树
我们需要开两棵树,一颗正着读元素,用来存比\(k\)小、且在k前面出现的元素有\(lis\)个,记录,另一颗倒着读元素,存比\(k\)大、且在k后面出现的元素有\(mor\)个,记录(请读者想想为什么倒着读)
最后for跑一遍,运用加法原理,记录总值\(tot\),输出即可
Part 6:再魔改代码(请按照上面的思路自行理解,因为作者懒得加注释了QwQ)
#include<algorithm>
#include<cstdio>
using namespace std;
const int maxn=30005;
typedef long long int ll;
ll n,a[maxn],t[maxn],lis[maxn],mor[maxn],tot;
void discretization(){
scanf("%lld",&n);
for(int i=1;i<=n;i++){
scanf("%lld",a+i);
t[i]=a[i];
}
sort(t+1,t+1+n);
int m=unique(t+1,t+1+n)-t-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(t+1,t+m+1,a[i])-t;
}
struct sag{
int l,r,v;
sag *ls,*rs;
inline void push_up() { v=ls->v+rs->v; }
inline bool in_range(const int L,const int R) { return (L<=l)&&(r<=R); }
inline bool outof_range(const int L,const int R) { return (r<L)||(R<l); }
void update(const int L,const int R){
if(in_range(L,R)) v++;
else if(!outof_range(L,R)){
ls->update(L,R);
rs->update(L,R);
push_up();
}
}
ll query(const int L,const int R){
if(in_range(L,R)) return v;
if(outof_range(L,R)) return 0;
return ls->query(L,R)+rs->query(L,R);
}
};
sag byte[maxn<<2],*pool=byte;
sag* New(const int L,const int R){
sag *u=pool++;
u->l=L,u->r=R;
if(L==R){
u->v=0;
u->ls=u->rs=NULL;
}else{
int Mid=(L+R)>>1;
u->ls=New(L,Mid);
u->rs=New(Mid+1,R);
u->push_up();
}
return u;
}
int main(){
discretization();
sag *rot1=New(1,n);
sag *rot2=New(1,n);
for(int i=1;i<=n;i++){
lis[i]=rot1->query(1,a[i]-1);
rot1->update(a[i],a[i]);
}
for(int i=n;i>0;i--){
mor[i]=rot2->query(a[i]+1,n);
rot2->update(a[i],a[i]);
}
for(int i=1;i<=n;i++)
tot+=mor[i]*lis[i];
printf("%lld",tot);
return 0;
}