【题解】P4137 - Rmq Problem / mex
前言
其实这道题更像是先手玩一下如何处理 \(mex\),然后根据性质来选择使用 线段树 来维护。这道题还可以用 值域分块 和 莫队 来做,值域线段树 \(+\) 线段树的做法大概是码量最大的做法了(。
题目大意
给定一个长度为 \(n\) 的数组 \(a\) 和 \(m\) 次询问,每次询问区间 \([l, r]\) 的 \(mex\),\(mex\) 定义为一个区间内 最小的 没有出现过的 自然数。
\(1 \leq n, m \leq 2 \times 10^5\) ,\(1 \leq l \leq r \leq n, 0 \leq a_i \leq 10^9\) 。
解题思路
这道题的关键还是在于 \(mex\) 的 性质 :对于一个长度为 \(n\) 的区间,其 \(mex\) 不可能超过 \(n\) 。下面给出证明:使 \(mex\) 最大的构造方案为 \(0\) 到 \(n - 1\) ,如果将其中任意一个数 \(x\) 替换为 \(k\) 且 \(k \geq n\) 。因为 \(x \leq n - 1 < k\) ,此时区间的 \(mex\) 为 \(x\) ,\(x < n\) 。我们可以把权值 \(\geq n\) 的值全部替换成 \(n\) 。
但是这样还不是很好维护 \(mex\) 。我们可以考虑把询问 离线处理 ,按 \(l\) 从小到大排序。我们先 预处理 原区间的 \(mex\) 前缀,也就是区间 \([1, 1]\) 的 \(mex\),区间 \([1, 2]\) 的 \(mex\) ,区间 \([1, 3]\) 的 \(mex\) ……然后我们把这些 \(mex\) 放在一棵线段树的叶子结点上。这棵线段树中只有叶子结点存储的信息有用,因此不需要维护非叶子结点的信息,也就是不用 push_up
。为了方便下文提到的权值线段树维护,我们令 \(a_i\) 等于 \(a_i + 1\),相应地,\(mex\) 最大为 \(n + 1\)。
对于左端点为 \(1\) 的询问,我们可以直接返回相应叶子结点的权值。但是当左端点右移时,我们必须考虑到 \([2, l - 1]\) 对于答案的影响。假设右移前的左端点为 \(l\) ,权值 \(a_l\) 下一次出现的位置在 \(nxt\) 。在区间 \([l + 1, nxt - 1]\) 中,权值 \(a_l\) 一定没有出现过。因此,当左端点右移时,它对剩余区间的影响就是区间 \([l + 1, nxt - 1]\) 多了一个可能的 \(mex\) 取值,给线段树上这个区间取最小值即可。
这样维护,每一次线段树单点查询的时候都考虑到了前面区间的所有影响,故而答案是正确的。
难点在于如何在 \(O(nlogn)\) 的时间复杂度内预处理 \(mex\) 前缀。我们可以再开一棵权值线段树,维护每个值出现的次数以及区间内 最小的 权值出现次数。处理到第 \(i\) 个位置时,我们在权值线段树第 \(i\) 个叶子结点的位置 \(+ 1\) 。每次查询时从权值线段树的根结点开始遍历,如果左子树的最小出现次数为 \(0\) ,说明左子树还有权值未出现过。因为左子树代表的权值比右子树小,进入左子树递归查找。否则,进入右子树查找。最坏情况下 \(1\) 到 \(n\) 全部出现过,此时会找到右下角的叶子结点,也就是 \(n + 1\) 。
最后,对于 \(nxt\) 的维护,我们可以开一个 vector[maxn]
来存储,vector[i]
存储权值 \(i\) 出现过的所有下标。另外用一个数组 idx[i]
表示 vector[i]
使用到了第 idx[i]
位。每次处理到位置 \(i\) 时令 idx[i]++
即可。初始时 idx[i] = 1
。还要在所有的 vector[i]
最后插入一个 \(n + 1\) 来维护在线段树上操作区间后缀的情况。
总而言之,这道题的线段树做法细节很多,如果读者在代码过程中遇到困难,可以自行斟酌是否对照题解代码(我觉得您可以适当抄一抄) 。
参考代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int maxn = 2e5 + 5;
const int maxm = 2e5 + 5;
const int inf = 0x3f3f3f3f;
struct ques {
int l, r, id;
} q[maxm];
struct val_tree {
struct node {
int l, r, num, val, minv;
} tree[maxn << 2];
void push_up(int k) {
tree[k].minv = min(tree[2 * k].minv, tree[2 * k + 1].minv);
}
void build(int k, int l, int r) {
tree[k].l = l;
tree[k].r = r;
if (l == r) {
tree[k].num = l;
tree[k].val = tree[k].minv = 0;
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
push_up(k);
}
void update(int k, int x) {
if (tree[k].l == tree[k].r) {
tree[k].val++;
tree[k].minv++;
return;
}
int mid = (tree[k].l + tree[k].r) / 2;
if (x <= mid) {
update(2 * k, x);
} else {
update(2 * k + 1, x);
}
push_up(k);
}
int query(int k) {
if (tree[k].l == tree[k].r) {
return tree[k].num;
} else if (!tree[2 * k].minv) {
return query(2 * k);
} else {
return query(2 * k + 1);
}
}
} valt;
struct seg_tree {
struct node {
int l, r, val, lazy;
} tree[maxn << 2];
void push_down(int k) {
if (tree[k].l == tree[k].r) {
tree[k].lazy = inf;
return;
}
tree[2 * k].val = min(tree[2 * k].val, tree[k].lazy);
tree[2 * k + 1].val = min(tree[2 * k + 1].val, tree[k].lazy);
tree[2 * k].lazy = min(tree[2 * k].lazy, tree[k].lazy);
tree[2 * k + 1].lazy = min(tree[2 * k + 1].lazy, tree[k].lazy);
tree[k].lazy = inf;
}
void build(int k, int l, int r) {
tree[k].l = l;
tree[k].r = r;
tree[k].lazy = inf;
if (l == r) {
tree[k].val = inf;
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
}
void update(int k, int l, int r, int x) {
if (tree[k].l >= l && tree[k].r <= r) {
tree[k].val = min(tree[k].val, x);
tree[k].lazy = min(tree[k].lazy, x);
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid) {
update(2 * k, l, r, x);
}
if (r > mid) {
update(2 * k + 1, l, r, x);
}
}
int query(int k, int x) {
if (tree[k].l == tree[k].r) {
return tree[k].val;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (x <= mid) {
return query(2 * k, x);
} else {
return query(2 * k + 1, x);
}
}
} segt;
int n, m;
int a[maxn], idx[maxn], ans[maxn];
vector<int> nxt[maxn];
bool cmp(ques a, ques b) {
return a.l < b.l;
}
int main() {
int val, p = 1;
scanf("%d%d", &n, &m);
valt.build(1, 1, n + 1);
segt.build(1, 1, n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
a[i] = min(a[i] + 1, n + 1);
nxt[a[i]].push_back(i);
}
for (int i = 1; i <= n + 1; i++) {
nxt[i].push_back(n + 1);
idx[i] = 1;
}
for (int i = 1; i <= m; i++) {
scanf("%d%d", &q[i].l, &q[i].r);
q[i].id = i;
}
sort(q + 1, q + m + 1, cmp);
for (int i = 1; i <= n; i++) {
valt.update(1, a[i]);
val = valt.query(1);
segt.update(1, i, i, val);
}
for (int i = 1; i <= m; i++) {
while (p < q[i].l) {
segt.update(1, p + 1, nxt[a[p]][idx[a[p]]] - 1, a[p]);
idx[a[p]]++, p++;
}
ans[q[i].id] = segt.query(1, q[i].r) - 1;
}
for (int i = 1; i <= m; i++) {
printf("%d\n", ans[i]);
}
return 0;
}