分块
分块
分块算法实质上是一种是通过分成多块后在每块上打标记以实现快速区间修改,区间查询的一种算法。其均摊时间复杂度为 \(O(\sqrt n)\)
分块算法相较于各种树形数据结构,具有简便易写,方便调试等多种优点。在同等数据规模下,如 \(1e5\) ,其时间效率并不会低太多,在考试时反而是一种有力的得分方法。
接下来讲一下分块算法的基本操作及性质:
为了使得其有着最稳定的时间复杂度,我们经常讲一个长度为 \(n\) 的序列分为 \(\sqrt n\) 个大小为 \(\sqrt n\) 的块,如果 \(n\) 不是完全平方数,则序列最右端会多出一个角块
如下图,就是一种序列的分块:
该长度为 \(10\) 的序列被分为了 $ 4$ 块,前三块的大小为 $\sqrt 10 $ 的近似值 \(3\) ,最后一个角块大小为 \(1\)
而我们要记录的一个值,就是每个序号代表的数,属于哪一块
如上图,\(1,2,3\) 就属于第一块,\(4,5,6\) 就属于第二块,\(7,8,9\) 就属于第三块,\(10\) 就属于第四块
可以得到获取每一个序号的所在块的代码是:
int n;//总个数
int block=sqrt(n);//每一块大小
for(int i=1;i<=n;i++)
{
belong[i]=(i-1)/block+1;//每一个数所在块
}
但是,如何用分块来维护区间最值?
我们举个例子:
给定一个长的为 n
数列,求出任意区间 [l,r]
的最大值 (1<=l,r<=n)(l<=r)
还是拿这张图,我们现在给每个点上加了一个权值,每个块维护一下块内最大值
当我们查询任意一个区间 [l,r]
时,如果 l
所在的块与 r
所在的块相同,如 [1,2]
,则直接暴力查询即可,时间复杂度 \(O(\sqrt n)\) 若其不在一个块但是块是相邻的,一样是暴力查询,时间复杂度 \(O(\sqrt n)\) 若其块不相邻,如 \([1,10]\) ,我们先处理两边的边块角块,先暴力查询 \(1\) 和 \(10\) 所在的块内最大值,最后直接查询中间块内最大值即可,时间复杂度 \(O(\sqrt n)\) 所以总时间复杂度 \(O(\sqrt n)\) 那如果加入了区间修改,又该怎么办呢?
对于整块修改,我们打个加法标记,即当前块增加了多少,最大值相应的就增加了多少
而多于边块角块,暴力修改,特判最大值即可
所以总时间复杂度也是 \(O(\sqrt n)\) 分块还能解决很多很麻烦的问题,比如寻找区间内前驱后继
例题 POJ 3468
#include <bits/stdc++.h>
#define rint register int
#define endl '\n'
using namespace std;
const int N = 1e5 + 5;
const int M = 3e2 + 5e1 + 5;
int n, m, len;
long long add[M], sum[M];
int w[N];
int get(int i) //查询该数在第几段
{
return i / len;
}
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
x = x * 10 + ch - '0', ch = getchar();
return x * f;
}
void change(int l, int r, int d) //将 l~r 区间所有数 +d
{
if (get(l) == get(r)) // 段内直接暴力
{
for (rint i = l; i <= r; i++)
{
w[i] += d;
sum[get(i)] += d;
}
}
else
{
int i = l, j = r;
while (get(i) == get(l))
{
w[i] += d;
sum[get(i)] += d;
i++;
}
while (get(j) == get(r))
{
w[j] += d;
sum[get(j)] += d;
j--;
}
for (rint k = get(i); k <= get(j); k++)
{
sum[k] += len * d;
add[k] += d;
}
}
return;
}
long long query(int l, int r) // 区间和
{
long long res = 0;
if (get(l) == get(r)) // 段内直接暴力
{
for (rint i = l; i <= r; i++)
{
res += w[i] + add[get(i)];
}
}
else
{
int i = l, j = r;
while (get(i) == get(l))
{
res += w[i] + add[get(i)];
i++;
}
while (get(j) == get(r))
{
res += w[j] + add[get(j)];
j--;
}
for (rint k = get(i); k <= get(j); k++)
{
res += sum[k];
}
}
return res;
}
signed main()
{
n = read();
m = read();
len = sqrt(n);
for (rint i = 1; i <= n; i++)
{
w[i] = read();
sum[get(i)] += w[i];
}
char op[2];
int l, r, d;
while (m--)
{
cin >> op;
l = read();
r = read();
if (*op == 'C')
{
cin >> d;
change(l, r, d);
}
else
{
cout << query(l, r) << endl;
}
}
return 0;
}