[BZOJ5016]一个简单的询问

给你一个长度为N的序列ai,1≤i≤N和q组询问,每组询问读入l1,r1,l2,r2,需输出
 
get(l,r,x)表示计算区间[l,r]中,数字x出现了多少次。

Input

第一行,一个数字N,表示序列长度。
第二行,N个数字,表示a1~aN
第三行,一个数字Q,表示询问个数。
第4~Q+3行,每行四个数字l1,r1,l2,r2,表示询问。
N,Q≤50000
N1≤ai≤N
1≤l1≤r1≤N
1≤l2≤r2≤N
注意:答案有可能超过int的最大值

Output

对于每组询问,输出一行一个数字,表示答案

Sample Input

5
1 1 1 1 1
2
1 2 3 4
1 1 4 4

Sample Output

4
1

考虑转化前面的式子

get(l1,r1,x)*get(l2,r2,x)
=(get(1,r1,x)-get(1,l1-1,x))*(get(1,r2,x)-get(1,l2-1,x))
=get(1,r1,x)*get(1,r2,x)-get(1,l1-1,x)*get(1,r2,x)-get(1,r1,x)*get(1,l2-1,x)+get(1,l1-1,x)*get(1,l2-1,x)

考虑如何计算get(1,a,x)*get(1,b,x)

用两个指针l,r表示现在已经知道get(1,l,x)*get(1,r,x)的值(记在sum里),那我们要如何求get(1,l±1,x)*get(1,r±1,x)的值呢?

以求get(1,l+1,x)*(1,r,x)举例。

get(1,l+1,x)*get(1,r,x)

=get(1,l,x)*get(1,r,x)+a[l]在[1,r]里的出现次数

那我们考虑用cnt1[]表示[1,l]的每个数的出现次数,cnt2[]表示[1,r]的每个数的出现次数。

那么我们记sum=get(1,l,x)*get(1,r,x),则上式可化为sum+cnt2[a[r]],则可以在O(1)的时间复杂度内实现转移

我们把一个询问拆分成4个询问,即get(1,r1,x)*get(1,r2,x),get(1,l1-1,x)*get(1,r2,x),get(1,r1,x)*get(1,l2-1,x),get(1,l1-1,x)*get(1,l2-1,x),把每个询问用一个二元组表示[r1,r2],[l1-1,r2],[l2-1,r1],[l1-1,l2-1]然后按莫队的方法排序,对每个询问打一个标记ok表示它前面的符号是'+'还是'-'。

然后你可以用一个均值不等式啥的来求出最佳分块大小,为了简单起见,我们取分块大小T为√n,则时间复杂度为O(4n√n)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
struct xxx{
    int l,lblock,r,id,ok;
}q[201000];
int cnt1[50100],cnt2[50100],a[50100];long long ans[50100];
bool cmp(xxx a,xxx b){return a.lblock!=b.lblock?a.lblock<b.lblock:a.r<b.r;}
int main()
{
    int n;scanf("%d",&n);int T=(int)sqrt((double)n);
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
    int Q;scanf("%d",&Q);int tot=0;
    for(int i=1;i<=Q;i++)
    {
        int l1,r1,l2,r2;scanf("%d%d%d%d",&l1,&r1,&l2,&r2);
        ++tot;q[tot].l=r1;q[tot].r=r2;q[tot].id=i;q[tot].ok=1;q[tot].lblock=(r1)/T;
        ++tot;q[tot].l=l1-1;q[tot].r=r2;q[tot].id=i;q[tot].ok=-1;q[tot].lblock=(l1-1)/T;
        ++tot;q[tot].l=l2-1;q[tot].r=r1;q[tot].id=i;q[tot].ok=-1;q[tot].lblock=(l2-1)/T;
        ++tot;q[tot].l=l1-1;q[tot].r=l2-1;q[tot].id=i;q[tot].ok=1;q[tot].lblock=(l1-1)/T;
    }
    sort(q+1,q+tot+1,cmp);
    int l=0,r=0;long long sum=0;
    for(int i=1;i<=tot;i++)
    {
        while(l<q[i].l){cnt1[a[++l]]++;sum+=cnt2[a[l]];}
        while(l>q[i].l){cnt1[a[l]]--;sum-=cnt2[a[l--]];}
        while(r<q[i].r){cnt2[a[++r]]++;sum+=cnt1[a[r]];}
        while(r>q[i].r){cnt2[a[r]]--;sum-=cnt1[a[r--]];}
        ans[q[i].id]+=sum*q[i].ok;
    }
    for(int i=1;i<=Q;i++)printf("%lld\n",ans[i]);
    return 0;
}

 

posted @ 2017-11-20 20:07  lher  阅读(432)  评论(0编辑  收藏  举报