BZOJ 3295 转

树状数组套主席树,看的云里雾里,好不容易懂了一点。。。弱成渣。。。

#include<cstring>
#include<string>
#include<iostream>
#include<queue>
#include<cstdio>
#include<algorithm>
#include<map>
#include<cstdlib>
#include<cmath>
#include<vector>
//#pragma comment(linker, "/STACK:1024000000,1024000000");

using namespace std;

#define INF 0x3f3f3f3f
#define maxn 100005

int n,m,cnt;
long long sum;
long long ans[maxn];
bool vis[maxn];
int b[maxn];
int a[maxn],indx[maxn];
int root[maxn];

void init()
{
    memset(root,0,sizeof root);
    memset(vis,false,sizeof vis);
    cnt=0;
    sum=0;
}

struct node
{
    int l,r;
    int sum;
    node()
    {
        l=r=sum=0;
    }
} t[90*maxn];

void update(int &rt,int pre,int pos,int l,int r)
{
    if(!rt) t[rt=++cnt]=t[pre];
    t[rt].sum++;
    if(l==r&&r==pos)
    {
        t[rt].sum=1;
        return ;
    }
    int mid=l+r>>1;
    if(pos<=mid) update(t[rt].l,t[pre].l,pos,l,mid);
    else update(t[rt].r,t[pre].r,pos,mid+1,r);
}

int query1(int rt,int pos,int l,int r)
{
    if(l==r&&r==pos)
    {
        return 0;
    }
    int mid=l+r>>1;
    if(pos<=mid) return query1(t[rt].l,pos,l,mid)+t[t[rt].r].sum;
    else return query1(t[rt].r,pos,mid+1,r);
}

int query2(int rt,int pos,int l,int r)
{
    if(l==r&&r==pos)
    {
        return 0;
    }
    int mid=l+r>>1;
    if(pos<=mid) return query2(t[rt].l,pos,l,mid);
    else return query2(t[rt].r,pos,mid+1,r)+t[t[rt].l].sum;
}

int lowbit(int x)
{
    return x&-x;
}

void add(int x,int pos)
{
    for(; x<=n; x+=lowbit(x)) update(root[x],root[x],pos,1,n);
}

long long ask(int x,int pos)
{
    long long temp=0;
    for(int i=x; i>0; i-=lowbit(i)) temp+=query1(root[i],pos,1,n);
    for(int i=n; i>0; i-=lowbit(i)) temp+=query2(root[i],pos,1,n);
    for(int i=x; i>0; i-=lowbit(i)) temp-=query2(root[i],pos,1,n);
    return temp;
}

int main()
{
    scanf("%d%d",&n,&m);
    init();
    for(int i=1; i<=n; i++)
    {
        scanf("%d",&a[i]);
        indx[a[i]]=i;
    }
    for(int i=0; i<m; i++)
    {
        scanf("%d",&b[m-i-1]);
        vis[b[m-i-1]]=1;
    }
    for(int i=1; i<=n; i++)
    {
        if(!vis[a[i]])
        {
            add(i,a[i]);
            sum+=ask(i,a[i]);
        }
    }
    for(int i=0; i<m; i++)
    {
        add(indx[b[i]],b[i]);
        sum+=ask(indx[b[i]],b[i]);
        ans[i]=sum;
    }
    for(int i=m-1; i>=0; i--)
    {
        printf("%lld\n",ans[i]);
    }
    return 0;
}

 

posted on 2016-10-08 15:50  very_czy  阅读(119)  评论(0编辑  收藏  举报

导航