树状数组求逆序对个数

首先对于树状数组,当前的理解是

对于一个1~n的序列,一共有n个前缀和,每个前缀的下标都有唯一的二进制分解形式

通过这个性质我们可以在分解前缀下标log的时间内,分解前缀加和的过程

加的时候,比如算1~10,我们知道10=2+8,先算9~10,长度为2,然后再算1~8,长度为8,分解完成,两步算出

那么递归地减小要算的前缀和,我们发现它总是对的

那么如何更新它呢,那我们就要看在统计和的过程中,要更新的位置对树状数组的哪些节点有贡献,很显然他一定对于

x-lowbit(x)进行若干次变换得到他当前的下标的值有贡献,那么它就对于cur+lowbit(cur)进行若干次迭代有贡献

直到它超过我们统计的范围N,因此范围N我们一定要知道,因此离散化使得N变小是很必要的,或者处理一些负值

 

如何统计逆序对呢

一种方法是,我们每插入一个数就计算插入时间在它前面但是值比他大的数有多少,很显然这是一个区间和sum(a+1,N)

另外一种方法是,按照原插入序列的倒序插入,这样我们需要检测的是时间在它后面但是值比它小的数有多少,很显然这是一个前缀和sum(a-1)

网上大部分都是第二种方法,这里我贴上我写的第一种方法

#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn=1e5+7;
int N,w,t[maxn];
struct node{
    int id,v;node(){};node(int id,int v):id(id),v(v){};
};
node a[maxn];
int lowbit(int x){
    return x&-x;
}
void add(int n,int x){
    while(n<=N){
        t[n]+=x;
        n+=lowbit(n);
    }
}
int sum(int n){
    int ans=0;
    while(n){
        ans+=t[n];
        n-=lowbit(n);
    }
    return ans;
}
bool cmp1(node a,node b){
    return a.v<b.v;
}
bool cmp2(node a,node b){
    return a.id<b.id;
}
int main(){
    int n,x;scanf("%d",&n);
    for(int i=1;i<=n;++i){
        scanf("%d",&x);
        a[i]=node(i,x);
    }
    sort(a+1,a+1+n,cmp1);
    int cnt=1,st=1,pre=a[1].v;
    for(int i=2;i<=n;++i){
        while(i<=n&&a[i].v==pre) i++;
        for(int j=st;j<i;++j){
            a[j].v=cnt;
        }
        st=i;pre=a[i].v;
        cnt++;
    }
    for(int j=st;j<=n;++j) a[j].v=cnt;
    //for(int i=1;i<=n;++i) printf("%d,",a[i].v);printf("\n");
    N=cnt;
    sort(a+1,a+1+n,cmp2);int ans=0;
    for(int i=1;i<=n;++i){
        ans+=sum(N)-sum(a[i].v);
        add(a[i].v,1);
    }
    printf("%d\n",ans);
    return 0;
}

 第二种写法...

#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn=1e5+7;
int N,w,t[maxn];
struct node{
    int id,v;node(){};node(int id,int v):id(id),v(v){};
};
node a[maxn];
int lowbit(int x){
    return x&-x;
}
void add(int n,int x){
    while(n<=N){
        t[n]+=x;
        n+=lowbit(n);
    }
}
int sum(int n){
    int ans=0;
    while(n){
        ans+=t[n];
        n-=lowbit(n);
    }
    return ans;
}
bool cmp1(node a,node b){
    return a.v<b.v;
}
bool cmp2(node a,node b){
    return a.id<b.id;
}
int main(){
    int n,x;scanf("%d",&n);
    for(int i=1;i<=n;++i){
        scanf("%d",&x);
        a[i]=node(i,x);
    }
    sort(a+1,a+1+n,cmp1);
    int cnt=1,st=1,pre=a[1].v;
    for(int i=2;i<=n;++i){
        while(i<=n&&a[i].v==pre) i++;
        for(int j=st;j<i;++j){
            a[j].v=cnt;
        }
        st=i;pre=a[i].v;
        cnt++;
    }
    for(int j=st;j<=n;++j) a[j].v=cnt;
    //for(int i=1;i<=n;++i) printf("%d,",a[i].v);printf("\n");
    N=cnt;
    sort(a+1,a+1+n,cmp2);int ans=0;
    for(int i=n;i>=1;--i){
        ans+=sum(a[i].v-1);
        add(a[i].v,1);
    }
    printf("%d\n",ans);
    return 0;
}

 

posted @ 2017-05-26 14:16  狡啮之仰  阅读(441)  评论(0编辑  收藏  举报