【BZOJ3295】动态逆序对(CQOI2011)-CDQ分治:三维偏序

测试地址:动态逆序对
做法:本人这几天学习了CDQ分治思想,感觉还是比较难懂,于是找到了比较好理解的经典应用——三维偏序问题来加深理解。
这题首先需要把问题转化为三维偏序问题,然后再使用CDQ分治解决。
首先这个题目是将元素一个一个删除,在每次删除之前询问逆序对数,从这个方面来看好像无法下手,那么我们不如反过来,看成是将元素一个一个插入,在每次插入之后询问逆序对数。那么每个元素我们就可以使用一个三维坐标(xi,yi,zi)来表示,其中xi指元素的插入时间(以插入先后顺序标号为1~M,一开始就在的标号为0),yi指元素在排列中的位置zi指元素的。那么对于一个点(xi,yi,zi),如果存在newi个点(xj,yj,zj)使得xixjyiyjzizjxixjyiyjzizj,那么在第xi次插入之后逆序对数就会增加newi个(想一想,为什么?)。于是我们就得到了一个变形的三维偏序问题,我们需要想办法求出所有的newi
由于N达到100000,所以O(N2)的暴力是绝对炸的。网上有人讲解三维偏序问题时说了一句精辟的话:一维排序,二维分治,三维数据结构。按照这个思路,我们首先把所有点按x从小到大排序,重新标号为1~N,然后分治。这里使用的分治方法是CDQ分治,CDQ分治是一种思想,包含递归处理左半、处理左半对右半的影响、递归处理右半三个步骤。以下只考虑怎么处理左半对右半的影响。
假设我们在处理一个区间[l,r],这个区间的中点为mid,那么首先分别对于区间[l,mid][mid+1,r]y从小到大排序,因为x已经有序了我们就不管x,我们对于右半区间的点一个一个处理影响。因为两边的y都是有序的,那么我们就只需要在左边指一个只会往右的指针,设这个指针当前指到i,而右边我们正在处理的点为j,如果yiyj,那么就在计数数组里的zi位置增加1,然后i自增1,一直到i>mid或者yi>yj为止。然后我们再求计数数组中zj的所有位置之和,就可以得到左半区间对点j做出的贡献,将其加入newj即可。我们注意到对于计数数组的修改涉及单点修改和区间求和,这个我们可以用代码量小的树状数组解决。以上我们就处理完了一种情况,而另一种情况是类似的,这里就不再赘述了。
经过证明,以上方法的时间复杂度为O(Nlog2N),可以通过全部数据。注意每次处理完后清空树状数组时不要鲁莽地使用memset,会TLE,应该按照原来的顺序再把加上的东西都给减掉。除此之外,要注意排序和处理的顺序,因为有时排序会破坏掉原来的顺序。
以下是本人代码:

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
using namespace std;
int n,m,pos[100010]={0};
ll ans[50010]={0},bit[100010]={0};
struct point3D
{
  int x,y,z,id;
}p[100010];

bool cmpx(point3D a,point3D b) {return a.x<b.x;}
bool cmpy1(point3D a,point3D b) {return a.y<b.y;}
bool cmpy2(point3D a,point3D b) {return a.y>b.y;}
bool cmpid(point3D a,point3D b) {return a.id<b.id;}

int lowbit(int x)
{
  return x&(-x);
}

void add(int x,ll d)
{
  for(int i=x;i<=n;i+=lowbit(i))
    bit[i]+=d;
}

ll query(int x)
{
  ll s=0;
  while(x)
  {
    s+=bit[x];
    x-=lowbit(x);
  }
  return s;
}

ll sum(int l,int r)
{
  return query(r)-query(l-1);
}

void solve(int l,int r)
{
  int mid=(l+r)>>1;
  if (l==r) return;
  solve(l,mid);

  int h;
  sort(p+l,p+mid+1,cmpy1);
  sort(p+mid+1,p+r+1,cmpy1);
  h=l;
  for(int i=mid+1;i<=r;i++)
  {
    while(h<=mid&&p[h].y<=p[i].y) add(p[h].z,1),h++;
    ans[m-p[i].x+1]+=sum(p[i].z,n);
  }
  for(int i=l;i<h;i++) add(p[i].z,-1);

  sort(p+l,p+mid+1,cmpy2);
  sort(p+mid+1,p+r+1,cmpy2);
  h=l;
  for(int i=mid+1;i<=r;i++)
  {
    while(h<=mid&&p[h].y>=p[i].y) add(p[h].z,1),h++;
    ans[m-p[i].x+1]+=sum(1,p[i].z);
  }
  for(int i=l;i<h;i++) add(p[i].z,-1);

  sort(p+l+1,p+r+1,cmpid);
  solve(mid+1,r);
}

int main()
{
  scanf("%d%d",&n,&m);
  for(int i=1;i<=n;i++)
  {
    p[i].y=i;
    scanf("%d",&p[i].z);
  }
  for(int i=1;i<=m;i++)
  {
    int a;
    scanf("%d",&a);
    pos[a]=m-i+1;
  }
  for(int i=1;i<=n;i++)
    p[i].x=pos[p[i].z];

  sort(p+1,p+n+1,cmpx);
  for(int i=1;i<=n;i++) p[i].id=i;
  solve(1,n);

  for(int i=m;i>=1;i--)
    ans[i]+=ans[i+1];
  for(int i=1;i<=m;i++)
    printf("%lld\n",ans[i]);

  return 0;
}
posted @ 2017-05-05 17:11  Maxwei_wzj  阅读(101)  评论(0编辑  收藏  举报