联赛模拟17_简单的区间
T1简单的区间
看到模数比较小,1e6的范围,可以开个数组,就有思路了
不同的区间max的位置不确定,我们考虑分治。对于每个区间,我们只计算跨过中点的区间贡献。
并且分2种情况,最大值在左边,最大值在右边。这样扫一边的时候,另一边的边界指针是单调的。
首先维护一个桶buc[]
我们以假设max在左边为例
假设此左指针 \(i\) 到 \(mid\) 之间的 \(sum\) 为 \(s1\)
\(mid+1\) 到右指针 \(j\) 之间的 \(sum\) 为 \(s2\)
对于枚举的左指针,右指针移动的同时用 \(buc[x]\),记录\(s2 mod k = x\) 的 \(s2\) 的个数
如果\((s1-max) mod k == y\) ,那么我们只需要找到有多少个\(s2\),与 \(s1-max\) 加起来能整除 \(k\) 即可(也就是找 \(buc[(k-y)%k]\) )
\(ma\) x在右边是一样的, \(buc[]\) 记录 \(s1\) ,对于 \(s2\) 去找合适的 \(s1\) 即可
注意:
1.\(max\) 在左边时,右指针 \(j\) 移动的判断 \(max[i]>=max[j]\), 那么避免重复,\(max\) 在右边,左指针 \(i\) 移动判断应为 \(max[i]<max[j]\) ,不再考虑取等
2.对于每个区间,\(buc[]\) 清空的时候不要 \(for(i=1->k)\) ,否则复杂度太高(2e6log),用个栈维护一下哪些该删就能保证复杂度为 \(O(nlogn)\)
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
using namespace std;
char buf[1<<20],*p1,*p2;
#define rint register int
#define gc() (p1==p2?(p2=buf+fread(p1=buf,1,1<<20,stdin),p1==p2?EOF:*p1++):*p1++)
#define read() ({\
rint x=0;register bool f=0;register char ch=gc();\
while(!isdigit(ch)) f|=ch=='-',ch=gc();\
while(isdigit(ch)) x=(x<<3)+(x<<1)+(ch&15),ch=gc();\
f?-x:x;\
})
const int maxn=3e5+5;
int n,k;
int a[maxn];
int Max[maxn];
int sta[maxn],top;
int buc[1000000+5];
int sum[maxn];
long long ans;
void solve(rint l,rint r){
if(l==r) return;
rint mid=(l+r)/2;
solve(l,mid),solve(mid+1,r);
while(top) buc[sta[top--]]=0;
Max[mid]=a[mid],Max[mid+1]=a[mid+1];
for(rint i=mid-1;i>=l;--i) Max[i]=max(a[i],Max[i+1]);
for(rint i=mid+2;i<=r;++i) Max[i]=max(a[i],Max[i-1]);
rint now,i=mid,j=mid+1;
while(i>=l){
while(j<=r&&Max[j]<=Max[i]){ // 这里 <=
const rint res=(sum[j]-sum[mid]+k)%k;
buc[res]++;
sta[++top]=res;
++j;
}
now=((sum[mid]-sum[i-1]-Max[i])%k+k)%k;
ans+=buc[(k-now)%k]; // 一定要mod k,因为now会==0
--i;
}
while(top) buc[sta[top--]]=0;
i=mid,j=mid+1;
while(j<=r){
while(i>=l&&Max[i]<Max[j]){ //这里 < ,防止算重
const rint res=(sum[mid]-sum[i-1]+k)%k;
buc[res]++;
sta[++top]=res;
--i;
}
now=((sum[j]-sum[mid]-Max[j])%k+k)%k;
ans+=buc[(k-now)%k];
++j;
}
}
int main(){
freopen("interval.in","r",stdin);
freopen("interval.out","w",stdout);
n=read(),k=read();
for(rint i=1;i<=n;++i) sum[i]=(sum[i-1]+(a[i]=read()))%k;
solve(1,n);
printf("%lld\n",ans);
return 0;
}