浅谈求逆序对
这里我们从易到难来介绍三种求逆序对数的方法。
top1:暴力枚举
时间复杂度:\(O(n^2)\)
emmmmm……
这好像真的没什么可说的……
code:
for(int i = 1; i < n; i++)
for(int j = i + 1; j <= n; j++)
if(a[i] > a[j]) ans++;
top2:归并排序
时间复杂度:\(O(n\log{n})\)
逆序对怎么还和排序扯上关系了呢qwq?
是这样的,在归并排序中,我们每次合并两个有序的序列\(a,b\),如果\(a_i>b_j\),由于a和b都是有序的,那么\(a_{i-l_1}>b_j\),所以我们可以直接更新答案:\(ans=ans+l_1-a_i+1\)
code:
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
long long a[510000],n,b[510000],tj=0;
void gb(int l,int r)
{
if(l==r) return ;
int mid=(r+l)/2,i;
gb(l,mid);
gb(mid+1,r);
int l1,l2,r1;
l1=l;
l2=l;
r1=mid+1;
while(l1<=mid&&r1<=r)
{
if(a[l1]>a[r1])
{
tj=tj+(mid-l1+1);
b[l2++]=a[r1++];
}
else
{
b[l2++]=a[l1++];
}
}
if(l1<=mid) for(i=l1;i<=mid;i++) b[l2++]=a[i];
if(r1<=r) for(i=r1;i<=r;i++) b[l2++]=a[i];
for(i=l;i<=r;i++) a[i]=b[i];
}
int main()
{
int i;
cin>>n;
for(i=1;i<=n;i++) scanf("%lld",&a[i]);
gb(1,n);
cout<<tj;
return 0;
}
top3:树状数组
时间复杂度:\(O(n\log{n})\)
树状数组求逆序对数大概是最难理解的一种了。
大体思路是我们把下标按照其对应的元素大小从大到小排序以后通过树状数组来快速处理每一个元素对答案的贡献。
首先,我们把给定的序列的下标从大到小排序。需要注意的是,对于指向元素大小相等的坐标,我们应按坐标从大到小排序。(其实这一步就是一个离散化,方便后面树状数组的处理)
...
bool cmp(int x,int y)
{
if(a[x] != a[y]) return a[x] > a[y];
return x > y;
}
int main()
{
n = read();
for(int i = 1; i <= n; i++) a[i] = read(),b[i] = i;
sort(b + 1, b + n + 1, cmp);
...
}
然后,我们按照排完序的坐标序列的顺序来更新答案。
具体思路是这样的:
对于下标\(i\),我们在树状数组中使\(位置i\)加\(1\),表示\(a[i]\)和\(a[j](j\in[i+1,n],且在下标序列中j比i靠后)\)会构成一组逆序对,而\(i\)指向的元素\(x\)一共会和\(query(i-1)\)个\(a[k](k\in[1,i-1]且在下标序列中k比i靠前)\)构成逆序对。
因为如果在下标序列中,靠前的下标指向的元素一定大于等于靠后的下标指向的元素,而在下标序列中比当前遍历到的下标靠前的下标一定会被优先处理并完成关于该下标答案的统计,所以这样做是正确的。
code:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ll long long
#define INF 0x7fffffff
#define re register
#define int long long
using namespace std;
int read()
{
register int x = 0,f = 1;register char ch;
ch = getchar();
while(ch > '9' || ch < '0'){if(ch == '-') f = -f;ch = getchar();}
while(ch <= '9' && ch >= '0'){x = x * 10 + ch - 48;ch = getchar();}
return x * f;
}
int n,a[1000005],shu[1000005],b[1000005],ans;
int lowbit(int x){return x & (-x);}
bool cmp(int x,int y)
{
if(a[x] != a[y]) return a[x] > a[y];
return x > y;
}
void updata(int x)
{
while(x <= n)
{
shu[x]++;
x += lowbit(x);
}
}
int query(int x)
{
int cnt = 0;
while(x > 0)
{
cnt += shu[x];
x -= lowbit(x);
}
return cnt;
}
signed main()
{
n = read();
for(int i = 1; i <= n; i++) a[i] = read(),b[i] = i;
sort(b + 1, b + n + 1, cmp);
for(int i = 1; i <= n; i++)
{
updata(b[i]);
ans = ans + query(b[i] - 1);
}
printf("%lld\n",ans);
return 0;
}