[BZOJ5110]Yazid的新生舞会
题目大意:
给你一个长度为$n(n\leq 5\times 10^5)$的序列$A_{1\sim n}$。求满足区间众数在区间内出现次数严格大于$\lfloor\frac{r-l+1}{2}\rfloor$的区间$[l,r]$的个数。
思路:
分治。
对于一个区间$[l,r]$,设$mid=\lfloor\frac{l+r}{2}\rfloor$,我们可以求出所有经过$mid$的区间内能够成为众数的所有数。
不难发现所有的区间众数满足如下一个性质:如果$x$是区间$[l,r]$的众数,那么对于$l\leq x\leq r$,$x$一定是区间$[l,k]$或区间$(k,r]$的众数。
利用这一性质,我们可以令$k=mid$,这样就可以$O(n)$从$mid$出发往左右两边扫,求出能够成为众数的所有数。
接下来枚举每个众数$x$,求一下当前$[l,r]$区间中,以$x$作为众数的子区间个数。
具体我们可以先从$mid$往左扫,设往左扫到的端点为$b$,记录一下对于不同的$b$,$mid-b+1-cnt[x]$不同取值的出现次数。然后再往右扫,求出对于当前右端点$e$,求出满足$e-b+1-cnt[x]>\lfloor\frac{e-b+1}{2}\rfloor$的区间$[b,e]$的个数,这可以用前缀和快速求出。
这样我们就统计了区间$[l,r]$,经过$mid$的所有子区间。
对于不经过$mid$的子区间可以递归求解。
递归树中,每一层区间长度加起来是$n$,可能的众数个数有$\log n$个,每一层的时间复杂度是$O(n\log n)$。总共有$\log n$层,总的时间复杂度是$O(n\log^2 n)$。
1 #include<cstdio> 2 #include<cctype> 3 #include<algorithm> 4 typedef long long int64; 5 inline int getint() { 6 register char ch; 7 while(!isdigit(ch=getchar())); 8 register int x=ch^'0'; 9 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 10 return x; 11 } 12 const int N=500001; 13 int a[N],pos[N],num[N],cnt[N*2]; 14 int64 ans; 15 void solve(const int &l,const int &r) { 16 if(l==r) { 17 ans++; 18 return; 19 } 20 const int mid=(l+r)/2; 21 solve(l,mid); 22 solve(mid+1,r); 23 for(register int i=mid;i>=l;i--) { 24 if(++cnt[a[i]]>(mid-i+1)/2) { 25 if(!pos[a[i]]) { 26 num[pos[a[i]]=++num[0]]=a[i]; 27 } 28 } 29 } 30 for(register int i=mid+1;i<=r;i++) { 31 if(++cnt[a[i]]>(i-mid)/2) { 32 if(!pos[a[i]]) { 33 num[pos[a[i]]=++num[0]]=a[i]; 34 } 35 } 36 } 37 for(register int i=l;i<=r;i++) { 38 pos[a[i]]=cnt[a[i]]=0; 39 } 40 for(register int i=1;i<=num[0];i++) { 41 int sum=r-l+1,max=sum,min=sum; 42 cnt[sum]=1; 43 for(register int j=l;j<mid;j++) { 44 if(a[j]==num[i]) { 45 sum++; 46 } else { 47 sum--; 48 } 49 max=std::max(max,sum); 50 min=std::min(min,sum); 51 cnt[sum]++; 52 } 53 if(a[mid]==num[i]) { 54 sum++; 55 } else { 56 sum--; 57 } 58 for(register int i=min;i<=max;i++) { 59 cnt[i]+=cnt[i-1]; 60 } 61 for(register int j=mid+1;j<=r;j++) { 62 if(a[j]==num[i]) { 63 sum++; 64 } else { 65 sum--; 66 } 67 ans+=cnt[std::min(max,sum-1)]; 68 } 69 for(register int i=min;i<=max;i++) { 70 cnt[i]=0; 71 } 72 } 73 num[0]=0; 74 } 75 int main() { 76 const int n=getint(); getint(); 77 for(register int i=1;i<=n;i++) { 78 a[i]=getint(); 79 } 80 solve(1,n); 81 printf("%lld\n",ans); 82 return 0; 83 }