[算法笔记]分块算法从入门到TLE
分块算法在学习之前一直觉得是一个高端大气上档次,有着与众不同的O(√N)的时间复杂度。
(打公式真是太烦了,不过如果我不打公式zichen0535巨佬肯定又要嘲讽我。。。)
直到我阅读多方博客,才发现,这tm就是一个流氓算法。
对于区间问题,
我们把要处理的所有元素从左到右分成M个等长区间(以下称“块”)和最后一个比正常块略小的块,那么显然除了最后一个区间,每个区间的长度都是,最后一个区间长度是N%M。
可以证明M等于√N的时候,时间复杂度最低,这个分析完原理后再证明。
需要维护的信息有:
belong[x]:元素x所在的块的编号,样例代码中为bl[x];
start[x]:编号为x的块的最左边的点,样例代码中为st[x];
end[x]:编号为x的块的最右边的点,样例代码中为ed[x];
我们处理一个l到r的操作的时候,分类讨论一下
如果l和r在一个块中或者l和r在相邻块中,那么直接从l到r暴力处理一遍,处理的次数一定小于块的大小,也就是M。
如果l所在的块和r所在的块中间还有其它块,那么我们先暴力处理l到l所在块的右端,然后暴力处理r所在区间的左端到r,处理次数小于2*M。然后处理中间的块,因为块的处理非常方便,可以另外建立数组维护块的变化信息(如加上某个值),所以我们每个块只需要处理一次,那么我们处理次数少于次。
所以我们的复杂度分析就是O(M+)。根据基本不等式( )可知,在M=时取得最小值√N。解得M=√N时取得最小值。
所以我们在解决一般的题目时都会把所有元素分成√N个区间。
那么对于Q个操作,最终时间复杂度就是O(Q√N)。
这里给一道例题(奇水) 传送门
附上我的代码(略丑)
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #define maxn 100005 using namespace std; inline void read(int &x){ x=0;int f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} x*=f; } inline void read(long long &x){ x=0;int f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} x*=f; } int N,M,len,tot; int st[maxn],ed[maxn],bl[maxn]; long long add[maxn],arr[maxn],sum[maxn]; void init() { len=sqrt(N); tot=N/len; if(N%len) tot++; for(int i=1;i<=N;i++) { bl[i]=(i-1)/len+1; sum[bl[i]]+=arr[i]; } for(int i=1;i<=tot;i++) { st[i]=(i-1)*len+1; ed[i]=i*len; } } void update(int l,int r,long long k) { if(r<=ed[bl[l]]) { for(int i=l;i<=r;i++) { arr[i]+=k; sum[bl[i]]+=k; } return ; } for(int i=l;i<=ed[bl[l]];i++) { arr[i]+=k; sum[bl[i]]+=k; } for(int i=bl[l]+1;i<bl[r];i++) { sum[i]+=len*k; add[i]+=k; } for(int i=st[bl[r]];i<=r;i++) { arr[i]+=k; sum[bl[i]]+=k; } } long long query(int l,int r) { long long ans=0; if(r<=ed[bl[l]]) { for(int i=l;i<=r;i++) { ans+=arr[i]; } return ans; } for(int i=l;i<=ed[bl[l]];i++) ans+=arr[i]+add[bl[i]]; for(int i=bl[l]+1;i<bl[r];i++) ans+=sum[i]; for(int i=st[bl[r]];i<=r;i++) ans+=arr[i]+add[bl[i]]; return ans; } int main() { read(N);read(M); for(int i=1;i<=N;i++) read(arr[i]); init(); int op,l,r; long long k; while(M--) { read(op);read(l);read(r); if(op==1) { read(k); update(l,r,k); } else { printf("%lld\n",query(l,r)); } } }