「补题笔记」 Codeforces #668 E. Fixed Point Removal
PS:因为本题要用到线段树,所以数组都是从\(1\)开始存的。
题意
给定一组有\(n\)个数的数组\(a\),如果\(a_i = i\),那么你可以将\(a_i\)消除,然后\(a_i\)后面的数下标都\(-1\)。一共有\(q\)次询问,每次询问会给定一组\((x,y)\),问如果不能动前\(x\)个数和后\(y\)个数,最多可以消除多少个数?
思路
PS:没有特殊说明,下文提到的\(i\)都是指的数组下标。
首先,对于每一次询问,我们都可以得到我们能操作的区间——\(l = x + 1, r = n - y\),那么我们能操作的区间就是\([l, r]\)。
暴力
很容易想到暴力的思路:
我们从\(l\)开始枚举,如果遇到\(a_i = i\),那么\(ans := ans + 1\);如果遇到\(a_i > i\),那么这个数则不可能被消除;如果遇到\(a_i < i\),那么,如果此时的\(ans \geqslant i - a_i\),那么说明我们可以在消除前\(ans\)个数时,可以消除\(a_i\),否则\(a_i\)不可能被消除。
这样,我们的时间复杂度是\(O(nq)\)的。再看一下数据范围:\(1 \leq n,q \leq 3 \times 10^5\),稳TLE。
改进
本题有一个很麻烦的地方就在于,它会限制前\(x\)个数不能动。如果前\(x\)个数不动,后面就可能有一些元素,从\(1\)开始消除,它们是能消除的,但是从\(x + 1\)开始,它们就不能消除了。
也就是说,对于每一个可能被消除的元素\(a_i\),它都存在一个边界\(j\),当\(x \leqslant j\)时,\(a_i\)是能被消除的,当\(x > j\)时,它就消除不了了。对于每个\(i\),我们称这个\(j\)为$ l[i]\(。那么,换句话说,当\)l \leqslant l[i]$时,这个元素能被消除;否则不能被消除。
那么,这个时候,显然答案就变成了,对于每一次询问的\(l, r\),我们看\([l,r]\)中有多少个\(l[i]\)落在区间\([l,r]\)上。
那么,对于这个查询,我们就可以考虑维护一个线段树,枚举每个\(i\),它的\(l[i]\)位置上的\(val +1\)即可,然后查询就直接用线段树区间查询\([l,r]\)就是答案。
PS:这里会有个问题,就是如果我先一次性把从\(1\)到\(n\)的所有\(l[i]\)都加进线段树里,那么我查询\([l,r]\)的时候,可能会有\(i > r\),但是\(l \leqslant l[i] \leqslant r\)的情况,此时会把下标大于\(r\)的数也算进去,但是显然我们的答案不能算进去这一部分。
要解决这个问题,我们就一次性把所有查询读进来离线处理。我们把所有的查询按照\(r\)从小到大排序,然后处理到第\(i\)个的时候,就把所有\(i \leqslant query[j].r\)的\(a_i\)放进线段树里。这样就能避免\(i > r\)那一部分元素的影响了。
求\(l[i]\)
这是最后一个麻烦事。我们考虑\(l[i]\)的含义:当\(l \leqslant l[i]\)时,这个元素能被消除;否则不能被消除。
所以,我们有:
- 当\(a_i = i\)时,显然,\(l[i] = i\);
- 当\(a_i > i\)时,显然,无论如何,我们都无法消除\(a_i\),此时我们给\(l[i]\)赋值为\(-1\);
- 当\(a_i < i\)时,\(a_i\)有可能会被消除,也有可能无法被消除:
- 当\(i\)前面所有能消除的数都被消除以后,\(a_i\)仍然小于\(i\),那么此时,\(a_i\)无法被消除,我们给\(l[i]\)赋值为\(-1\);
- 当\(i\)前面的\(i - a_i\)个数被消除以后,\(a_i\)就能被消除了。此时,我们消除最靠近\(i\)的\(i - a_i\)个数显然是最优的。
前面的都很简单,我们现在只考虑第三点:如何考虑消除这\(i - a_i\)个数?
首先,我们是从\(1\)到\(n\)枚举的。这保证了线段树上所有\(+1\)的点,都是\(i\)前面的数造成的;
然后,对于每个\(l[i]\),我们在线段数上在\(l[i]\)这个位置\(+1\),这保证了一个能被消除的点的贡献一定是\(1\);
所以,我们要考虑消除最靠近\(i\)的\(i - a_i\)个点,就可以变成,我们找能被消除的第\(num - (i - a_i) + 1\)个点。这个点,就是\(l[i]\)。
所以,对于第三点,找\(l[i]\)就变成了找一个\(j\),满足\(sum(1,j) = num - (i - a_i) + 1\)。(也可以理解成找一个\(j\),满足\(sum(j, i - 1) = i - a_i\))。其中\(num\)是从\(1\)到\(i - 1\)能被消除的数的总数。这个操作可以在线段树上\(O(\log n)\)完成,代码如下:
int findloc(int rt, int l, int r, const int &k) {
if (l == r)
return l;
int mid = (l + r) >> 1;
if (seg[rt << 1] >= k)
return findloc(rt << 1, l, mid, k);
else
return findloc(rt << 1 | 1, mid + 1, r, k - seg[rt << 1]);
}
其中的\(k\)就是上面说到的\(j\)。
代码
为了方便处理,我在代码中一开始就直接把所有的\(a_i\)替换成了\(i - a_i\)。线段树可以重复使用,不过重复使用前不要忘记build
一波初始化。
#include <cstdio>
#include <cstring>
#include <algorithm>
const int maxn = 3e5 + 5;
using namespace std;
struct Triple {
int l, r, id, ans;
}t[maxn];
int a[maxn], n, q, f[maxn], l[maxn];
int seg[maxn << 2];
inline int cmp(const Triple &a, const Triple &b) { // 按照r排序
return a.r < b.r;
}
inline int cmp2(const Triple &a, const Triple &b) { // 按照id排序,处理完了最后输出
return a.id < b.id;
}
/***********线段树***********/
inline void pushup(int rt) {
seg[rt] = seg[rt << 1] + seg[rt << 1 | 1];
}
// 用于初始化
void build(int rt, int l, int r) {
if (l == r) {
seg[rt] = 0;
return ;
}
int mid = (l + r) >> 1;
build(rt << 1, l, mid);
build(rt << 1 | 1, mid + 1, r);
pushup(rt);
}
// 单点修改
void update(int rt, int l, int r, const int &loc, const int &val) {
if (l == r) {
seg[rt] += val;
return ;
}
int mid = (l + r) >> 1;
if (loc <= mid) update(rt << 1, l, mid, loc, val);
else if (loc > mid) update(rt << 1 | 1, mid + 1, r, loc, val);
pushup(rt);
}
// 查询sum[1, r] == k的r的位置
int findloc(int rt, int l, int r, const int &k) {
if (l == r)
return l;
int mid = (l + r) >> 1;
if (seg[rt << 1] >= k)
return findloc(rt << 1, l, mid, k);
else
return findloc(rt << 1 | 1, mid + 1, r, k - seg[rt << 1]);
}
// 区间查询求和
int query(int rt, int l, int r, const int &L, const int &R) {
if (L <= l && r <= R) {
return seg[rt];
}
int mid = (l + r) >> 1;
int res = 0;
if (L <= mid) res += query(rt << 1, l, mid, L, R);
if (R > mid) res += query(rt << 1 | 1, mid + 1, r, L, R);
pushup(rt);
return res;
}
/***********线段树***********/
int main() {
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++)
scanf("%d", a + i);
for (int i = 1; i <= n; i++)
a[i] = i - a[i];
build(1, 1, n);
int num = 0;
for (int i = 1; i <= n; i++) {
if (a[i] < 0) l[i] = -1;
else if (a[i] == 0) {
l[i] = i;
num++;
update(1, 1, n, i, 1);
}
else {
int tmp = num - a[i] + 1;
if (tmp <= 0) l[i] = -1;
else {
l[i] = findloc(1, 1, n, tmp);
update(1, 1, n, l[i], 1);
num++;
}
}
}
for (int i = 0; i < q; i++) {
scanf("%d%d", &t[i].l, &t[i].r);
t[i].l++;
t[i].r = n - t[i].r;
t[i].id = i;
}
sort(t, t + q, cmp);
int cur = 1;
build(1, 1, n);
for (int i = 0; i < q; i++) {
while (cur <= t[i].r) {
if (l[cur] > 0)
update(1, 1, n, l[cur], 1);
cur++;
}
t[i].ans = query(1, 1, n, t[i].l, t[i].r);
}
sort(t, t + q, cmp2);
for (int i = 0; i < q; i++)
printf("%d\n", t[i].ans);
return 0;
}