Sparse Table
Sparse Table 可用于解决这样的问题:给出一个 \(n\) 个元素的数组 \(a_1, a_2, \cdots, a_n\),支持查询操作计算区间 \([l,r]\) 的最小值(或最大值)。这种问题被称为区间最值查询问题(Range Minimum/Maximum Query,简称 RMQ 问题)。预处理的时间复杂度为 \(O(n \log n)\),预处理后数组 \(a\) 的值不可以修改,一次查询操作的时间复杂度为 \(O(1)\)。
例题:P2880 [USACO07JAN] Balanced Lineup G
有一个包含 \(n\) 个数的序列 \(h_i\),有 \(q\) 次询问,每次询问 \(h_{a_i}, h_{a_i + 1}, \cdots, h_{b_i - 1}, h_{b_i}\) 中最大值与最小值的差。
数据范围:\(1 \le n \le 5 \times 10^4, \ 1 \le q \le 1.8 \times 10^5, \ 1 \le h_i \le 10^6, \ a_i \le b_i\)。
分析:题目要求最大值和最小值的差难以直接求出,通常需要分别求解最大值和最小值。最直接的做法是每次遍历区间中的每一个数,记录最大值和最小值。这样可以正确求出正确答案,但是效率低下,时间复杂度高达 \(O(nq)\),无法通过本题。
之所以这样做效率低下,是因为所有询问区间可能有着大量的重叠,这些重叠部分被多次遍历到,因此产生了大量的重复。如果可以通过预处理得到一些区间的最小值,再通过这些区间拼凑每一个询问区间,就可以提高效率。
预处理前缀和可以拼凑出任意区间的和,但是这个思路不能直接搬到最值查询问题中。原因在于区间和可以从一个大区间中减去一部分小区间得到,而区间最值不行,所以只能用小区间去拼出大区间。如何选择预处理的区间就成为关键,选择的区间既要能够拼出任意区间,数量少又不能太多,并且预处理和查询都要高效。
可以预处理以每一个位置为开头,长度为 \(2^0, 2^1, \cdots, 2^{\lfloor \log_2 n \rfloor}\) 的所有区间最值。下面以最大值为例,用 \(f_{i,j}\) 表示 \(h_i, h_{i+1}, \cdots, h_{i+2^j-2}, h_{i+2^j-1}\) 中的最大值,用递推的方式计算所有的 \(f\),转移为 \(f_{i,j} = \max (f_{i,j-1}, f_{i+2^{j-1}, j-1})\)。计算所有的 \(f\) 的过程为预处理,预处理的时间复杂度为 \(O(n \log n)\)。
void init() {
for (int i = 1; i <= n; i++) {
f[i][0] = h[i];
}
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i <= n - (1 << j) + 1; i++) {
f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
}
接下来解决查询的问题,设需要查询最大值的区间是 \([l,r]\)。记区间长度为 \(L\),则该区间可以拆分为 \(O(\log L)\) 个小区间。对 \(L\) 做二进制拆分,从 \(l\) 开始向后跳,每次跳跃的量是一个 \(2\) 的幂,从而拼出整个区间。单词查询时间复杂度为 \(O(\log n)\)。
int query(int l, int r) {
int len = r - l + 1, ans = -INF, cur = l;
for (int i = 0; (1 << i) <= len; i++) {
if ((len >> i) & 1) {
ans = max(ans, f[cur][i]);
cur += (1 << i);
}
}
return ans;
}
更进一步,查询区间最值时,区间合并的过程允许重叠,因此只需要找到两个长度为 \(2^k\) 的区间合并得到 \([l,r]\)。令 \(k\) 为满足 \(2^k \le r-l+1\) 的最大整数,区间 \([l, l+2^k-1]\) 和区间 \([r-2^k+1,r]\) 合并起来覆盖了需要查询的区间 \([l,r]\)。
int query(int l, int r) {
int k = log_2[r - l + 1]; // 可以预处理log_2的表
return max(f[l][k], f[r - (1 << k) + 1][k]);
}
参考代码
#include <cstdio>
#include <algorithm>
using std::min;
using std::max;
const int N = 50005;
const int LOG = 16;
int h[N], f_min[N][LOG], f_max[N][LOG], log_2[N];
void init(int n) {
log_2[1] = 0;
for (int i = 2; i <= n; i++) log_2[i] = log_2[i >> 1] + 1; // 预处理对数表
for (int i = 1; i <= n; i++) {
f_min[i][0] = f_max[i][0] = h[i];
}
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i <= n - (1 << j) + 1; i++) {
f_min[i][j] = min(f_min[i][j - 1], f_min[i + (1 << (j - 1))][j - 1]);
f_max[i][j] = max(f_max[i][j - 1], f_max[i + (1 << (j - 1))][j - 1]);
}
}
}
int query(int l, int r, int flag) { // flag为1时查询最大值,为0时查询最小值
int k = log_2[r - l + 1];
if (flag) return max(f_max[l][k], f_max[r - (1 << k) + 1][k]);
else return min(f_min[l][k], f_min[r - (1 << k) + 1][k]);
}
int main()
{
int n, q; scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) scanf("%d", &h[i]);
init(n);
for (int i = 1; i <= q; i++) {
int a, b; scanf("%d%d", &a, &b);
printf("%d\n", query(a, b, 1) - query(a, b, 0));
}
return 0;
}
Sparse Table 预处理部分的时间复杂度为 \(O(n \log n)\),查询一次的时间复杂度为 \(O(1)\),总的时间复杂度为 \(O(n \log n)\)。
Sparse Table 不仅可以求区间最大值和最小值,还可以处理符合结合律和幂等律(与自身做运算,结果仍是自身)的信息查询,如区间最大公约数、区间最小公倍数、区间按位或、区间按位与等。
例题:P7333 [JRKSJ R1] JFCA
分析:看到环形,先破环成链。
看起来每个点的答案 \(O(1)\) 求得不太容易,但每个点的答案具有二分性。
对于每个点,二分答案,查询左右两段区间的最大值看是否大于等于 \(b_i\)。
因为没有修改,所以区间最值可以用 Sparse Table 维护,总体时间复杂度为 \(O(n \log n)\)。
参考代码
#include <cstdio>
#include <algorithm>
using std::max;
using std::min;
const int N = 300005;
const int LOG = 17;
int a[N], b[N], log_2[N], st[N][LOG], ans[N];
int query(int l, int r) {
int len = log_2[r - l + 1];
return max(st[l][len], st[r - (1 << len) + 1][len]);
}
int main()
{
int n; scanf("%d", &n);
for (int i = 2; i <= n; i++) log_2[i] = log_2[i / 2] + 1;
for (int i = 1; i <= n; i++) {
// 破环成链
scanf("%d", &a[i]); a[i + n * 2] = a[i + n] = a[i];
st[i][0] = st[i + n][0] = st[i + n * 2][0] = a[i];
ans[i] = n;
}
// Sparse Table 维护区间最大值
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i <= 3 * n - (1 << j) + 1; i++) {
st[i][j] = max(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}
}
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
for (int i = n + 1; i <= n * 2; i++) {
// left 二分左边第一个位置
int l = i - n + 1, r = i - 1, res = -1;
while (l <= r) {
int mid = (l + r) / 2;
if (query(mid, i - 1) >= b[i - n]) {
l = mid + 1; res = mid;
} else {
r = mid - 1;
}
}
if (res != -1) ans[i - n] = min(ans[i - n], i - res);
// right 二分右边第一个位置
l = i + 1; r = i + n - 1; res = -1;
while (l <= r) {
int mid = (l + r) / 2;
if (query(i + 1, mid) >= b[i - n]) {
r = mid - 1; res = mid;
} else {
l = mid + 1;
}
}
if (res != -1) ans[i - n] = min(ans[i - n], res - i);
}
for (int i = 1; i <= n; i++) printf("%d ", ans[i] == n ? -1 : ans[i]);
return 0;
}