统计逆序对的两种解法
统计逆序对的两种解法
归并排序(mergeSort)
逆序对定义
\(i<j\) 但\(a[i]>a[j]\),假设我们分别使得通过mergeSort使得左右半边有序
即\(a[1]...a[mid]\) 递增, \(a[mid+1]....a[n]\)递增,我们需要通过merge操作,完成整个的排序和新增逆序对的计数,较小值出现在左半边记为 a[i],出现在右半边即为 a[j],那么每次出现在右半边,意味左半边比a[i]大的数都比a[j]大,由此可以统计逆序对
HDU1394
代码实现
#include<bits/stdc++.h>
using namespace std;
#define db(x) cout<<"["<<#x<<"]="<<x<<endl
/*
归并排序求逆序对+规律
*/
const int maxn = 5010;
int a[maxn];
int c[maxn];
int b[maxn];
//mergeSort
int n;
int merge1(int* a,int l1,int r1,int l2,int r2){
int p1=l1,p2 = l2;
int t = 0;
int cnt = 0;
//db(a[p1]);db(a[p2]);
while(p1<=r1&&p2<=r2){
if(a[p1]<a[p2]){
b[t] = a[p1];
p1++;
t++;
}
else{//a[p1]>a[p2]; a[p2] 小于 p1...r1所有数
b[t] = a[p2];
cnt+=(r1-p1+1);
//db(cnt);
p2++;
t++;
}
}
while(p1<=r1){b[t]=a[p1];p1++,t++;}
while(p2<=r2){b[t]=a[p2];p2++,t++;}
for(int k=0;k<t;k++){
a[l1+k] = b[k];
}
//db(cnt);db(l1);db(r1);db(l2);db(r2);
return cnt;
}
int mergeSort(int* a,int l,int r){
if(l==r) return 0;
int cnt = 0;
int mid = (l+r)>>1;
cnt+=mergeSort(a,l,mid);
cnt+=mergeSort(a,mid+1,r);
cnt+=merge1(a,l,mid,mid+1,r);
return cnt;
}
int main(){
while(cin>>n){
for(int i=0;i<n;i++){cin>>a[i];c[i]=a[i];}
int tmp = mergeSort(a,0,n-1);
//db(tmp);
int mint = tmp;
for(int i=0;i<n-1;i++){
tmp +=n-1-2*c[i];
//db(tmp);
mint = min(tmp,mint);
}
cout<<mint<<endl;
}
return 0;
}
线段树
线段树的解法非常简单,每次插入a[i] ,同时对a[i]+1....n-1进行计数;
此时要求元素范围不能太大,当然如果是在\(1..n\)之间,那么非常理想
代码实现
#include<bits/stdc++.h>
using namespace std;
#define db(x) cout<<"["<<#x<<"]="<<x<<endl
const int maxn = 5e3+10;
struct node{
int l,r,num; //num维护的信息是节点插入的区间插入节点的数目
}tr[maxn<<2];//线段树
int a[maxn];
void build(int n,int x,int y){//n是根节点下标,x,y是维护的区间范围
tr[n].l = x,tr[n].r = y;
tr[n].num = 0;
if(x==y) return ;
int mid = (x+y)>>1;//no over
build(n<<1,x,mid);
build(n<<1|1,mid+1,y);
tr[n].num = tr[n<<1].num+tr[n<<1|1].num;
}
void modify(int n,int p){//跟新区间单点p的信息
int l = tr[n].l, r= tr[n].r;
if(l==r&&l==p){//found
tr[n].num=1;
return ;
}
int mid = (l+r)>>1;
if(p<=mid) modify(n<<1,p);
if(p>mid) modify(n<<1|1,p);
tr[n].num = tr[n<<1].num+tr[n<<1|1].num;
}
int query(int n,int x,int y){
int l = tr[n].l , r=tr[n].r;
int mid = (l+r)>>1;
int ans = 0;
if(l>=x&&r<=y){//x,y覆盖了l,r
return tr[n].num;
}
if(x<=mid) ans+=query(n<<1,x,y);
if(y>mid) ans+=query(n<<1|1,x,y);
return ans;
}
int n;
int main(){
while(cin>>n){
build(1,0,n-1);
int ans = 0;
for(int i=0;i<n;i++){
cin>>a[i];
int t=query(1,a[i]+1,n-1);
//db(t);
ans+=t;
modify(1,a[i]);
}
int mint = ans;
//db(ans);
for(int i=0;i<n-1;i++){
ans+=(n-1-2*a[i]);
mint = min(ans,mint);
}
cout<<mint<<endl;
}
return 0;
}