分块
首先我们要了解一个问题:为什么要用分块
我们拿一道题目举例:
例题
给定一个长度为 \(N\) 的数列 \(A\),以及 \(M\) 条指令,每条指令可能是以下两种之一:
C l r d
,表示把 \(A[l],A[l+1],…,A[r]\) 都加上 \(d\)。Q l r
,表示询问数列中第 \(l \sim r\) 个数的和。
对于每个询问,输出一个整数表示答案。
输入格式
第一行两个整数 \(N,M\)。
第二行 \(N\) 个整数 \(A[i]\)。
接下来 \(M\) 行表示 \(M\) 条指令,每条指令的格式如题目描述所示。
输出格式
对于每个询问,输出一个整数表示答案。
每个答案占一行。
数据范围
\(1 \le N,M \le 10^5\),
\(|d| \le 10000\),
\(|A[i]| \le 10^9\)
输入样例:
10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4
输出样例:
4
55
9
15
大家可能会用线段树或树状数组解决这题,但是它们的码量和思路……
段错误树写一大堆递归(还要用懒标记),树状数组思路难想,理解很难。
然后就是调代码(都是血与泪的教训)
分块本质上就是一个优化的暴力算法,它的思路通俗易懂,虽然时间复杂度不及那些高级数据结构,但仔细想想,线段树和树状数组处理每一步的时的效率为\(\log n\),分块的效率则为\(\sqrt n\),显然\(\sqrt n\)比\(\log n\)要处理的次数多上不少,但也是属于能过的状态,况且线段树因为巨大的常数有时还不如分块。
实测:
分块的思路
我们可以将一个序列分成\(\sqrt n\)块,如下图(当然最后一块不够的话就能分多少就分多少,不影响最终答案)
这里需要运用一点懒标记的思想,我们需要维护两个东西:
1、add:本段的所有数都要加上add
2、sum:本段的真实和是多少(算上add)
第一个操作:修改
修改的时候需要分两种情况:
1、完整段(蓝色部分):
\(add = add + d\)
\(sum = sum + d \times length\)(length为段的长度)
2、段内(红色部分):
直接暴力,枚举所有数
\(w_i = w_i + d\)
\(sum = sum + d\)
第二个操作:查询
同上,查询时也需要分两种情况:
1、完整段(蓝色部分):
直接累加sum即可
2、段内(红色部分):
直接暴力,枚举所有数,求和
其实如果在考试中出现这种有关序列的问题,如果实在想不出来正解,那就打个分块,一般能拿到70~80分
代码
#include <cmath>
#include <iostream>
using namespace std;
typedef long long ll;
const int N = 100010, M = 350;
int n, m, len;
ll add[M], sum[M];
int w[N];
int get(int i) { // 第i的位置映射到了哪一块
return i / len;
}
void change(int l, int r, int d) {
if (get(l) == get(r)) { // 段内,直接暴力
for (int i = l; i <= r; ++i)
w[i] += d, sum[get(i)] += d;
} else {
int i = l, j = r;
while (get(i) == get(l))
w[i] += d, sum[get(i)] += d, ++i;
while (get(j) == get(r))
w[j] += d, sum[get(j)] += d, --j;
for (int k = get(i); k <= get(j); ++k) {
sum[k] += len * d;
/*
这里有童鞋可能会问:如果最后一段有可能不足len长呢?
我们可以先将它补全,不影响最后的答案,因为题目是不会询问到数列以外的部分的
*/
add[k] += d;
}
}
}
ll query(int l, int r) {
ll res = 0;
if (get(l) == get(r)) { // 段内,直接暴力
for (int i = l; i <= r; ++i)
res += w[i] + add[get(i)];
} else {
int i = l, j = r;
while (get(i) == get(l))
res += w[i] + add[get(i)], ++i;
while (get(j) == get(r))
res += w[j] + add[get(j)], --j;
for (int k = get(i); k <= get(j); ++k)
res += sum[k];
}
return res;
}
int main() {
scanf("%d%d", &n, &m);
len = sqrt(n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &w[i]);
sum[get(i)] += w[i];
}
char op[2];
int l, r, d;
while (m--) {
scanf("%s%d%d", op, &l, &r);
if (*op == 'C') {
scanf("%d", &d);
change(l, r, d);
} else
printf("%lld\n", query(l, r));
}
return 0;
}