整体二分
整体二分
一类题目可以二分解决,但是有多次询问,于是考虑整体二分,主要思想为把多个查询一起解决,是一个离线算法。
使用条件:
- 答案可以二分求得。
- 允许离线。
- 修改对判定答案的贡献互相独立,修改之间互相独立。
- 修改如果对判定答案有贡献,则贡献为一与判定标准无关的定值。
实现
首先把所有操作按时间顺序存入数组中,然后开始分治。
整体二分函数 solve(l, r, L, R)
表示操作 \([L, R]\) 的答案在 \([l, r]\) 中。
若 \(l = r\) ,则说明找到答案。否则在每一层分治中,利用数据结构统计当前查询的答案和 \(mid = \dfrac{l + r}{2}\) 之间的关系,将当前处理的操作序列分为两份并分别递归处理。
需要注意的是,在整体二分过程中,若当前处理的值域为 \([l, r]\) ,则此时最终答案范围不在 \([l, r]\) 的询问会在其他时候处理。
如果分治中用线性结构维护,时间复杂度 \(O(n \log V)\) 。
应用
求解k小值
给出 \(a_{1 \sim n}\) ,\(m\) 次询问 \(a_{l \sim r}\) 的 \(k\) 小值。
\(n, m \leq 2 \times 10^5\)
处理当前层时只要用树状数组维护值与 \(mid\) 的关系即可。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 2e5 + 7;
struct Node {
int l, r, k, id;
} nd[N << 1], tmp1[N << 1], tmp2[N << 1];
int ans[N];
int n, m;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace BIT {
int c[N];
inline void update(int x, int k) {
for (; x <= n; x += x & -x)
c[x] += k;
}
inline int query(int x) {
int res = 0;
for (; x; x -= x & -x)
res += c[x];
return res;
}
} // namespace BIT
void solve(int l, int r, int L, int R) {
if (L > R)
return;
if (l == r) {
for (int i = L; i <= R; ++i)
if (nd[i].id)
ans[nd[i].id] = l;
return;
}
int mid = (l + r) >> 1, lp = 0, rp = 0;
for (int i = L; i <= R; ++i)
if (nd[i].id) {
int x = BIT::query(nd[i].r) - BIT::query(nd[i].l - 1);
if (nd[i].k <= x)
tmp1[lp++] = nd[i];
else
nd[i].k -= x, tmp2[rp++] = nd[i];
} else {
if (nd[i].k <= mid)
tmp1[lp++] = nd[i], BIT::update(nd[i].l, 1);
else
tmp2[rp++] = nd[i];
}
for (int i = 0; i < lp; ++i)
if (!tmp1[i].id && tmp1[i].k <= mid)
BIT::update(tmp1[i].l, -1);
memcpy(nd + L, tmp1, sizeof(Node) * lp);
memcpy(nd + L + lp, tmp2, sizeof(Node) * rp);
solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= n; ++i)
nd[i].l = i, nd[i].k = read();
for (int i = 1; i <= m; ++i)
nd[n + i].l = read(), nd[n + i].r = read(), nd[n + i].k = read(), nd[n + i].id = i;
solve(-inf, inf, 1, n + m);
for (int i = 1; i <= m; ++i)
printf("%d\n", ans[i]);
return 0;
}
给出 \(a_{1 \sim n}\) ,\(m\) 次操作:
- 修改 \(a_x\) 为 \(k\) 。
- 询问 \(a_{l \sim r}\) 的 \(k\) 小值。
\(n, m \leq 10^5\)
修改就是把原来的删掉在加上新的值而已。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7;
struct Node {
int op, l, r, k, id;
} nd[N * 3], tmp1[N * 3], tmp2[N * 3];
int a[N], ans[N];
int n, m, tot, cntq;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
inline char readc() {
char c = getchar();
while (c != 'Q' && c != 'C')
c = getchar();
return c;
}
namespace BIT {
int c[N];
inline void update(int x, int k) {
for (; x <= n; x += x & -x)
c[x] += k;
}
inline int query(int x) {
int res = 0;
for (; x; x -= x & -x)
res += c[x];
return res;
}
} // namespace BIT
void solve(int l, int r, int L, int R) {
if (L > R)
return;
if (l == r) {
for (int i = L; i <= R; ++i)
if (nd[i].id)
ans[nd[i].id] = l;
return;
}
int mid = (l + r) >> 1, lp = 0, rp = 0;
for (int i = L; i <= R; ++i)
if (nd[i].op == 1) {
if (abs(nd[i].k) <= mid)
tmp1[lp++] = nd[i], BIT::update(nd[i].l, 1);
else
tmp2[rp++] = nd[i];
} else if (nd[i].op == 2) {
if (abs(nd[i].k) <= mid)
tmp1[lp++] = nd[i], BIT::update(nd[i].l, -1);
else
tmp2[rp++] = nd[i];
} else {
int x = BIT::query(nd[i].r) - BIT::query(nd[i].l - 1);
if (nd[i].k <= x)
tmp1[lp++] = nd[i];
else
nd[i].k -= x, tmp2[rp++] = nd[i];
}
for (int i = 0; i < lp; ++i)
if (tmp1[i].op == 1 && tmp1[i].k <= mid)
BIT::update(tmp1[i].l, -1);
else if (tmp1[i].op == 2 && tmp1[i].k <= mid)
BIT::update(tmp1[i].l, 1);
memcpy(nd + L, tmp1, sizeof(Node) * lp);
memcpy(nd + L + lp, tmp2, sizeof(Node) * rp);
solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= n; ++i)
nd[++tot] = (Node) {1, i, 0, a[i] = read(), 0};
for (int i = 1; i <= m; ++i) {
if (readc() == 'C') {
int x = read(), k = read();
nd[++tot] = (Node) {2, x, 0, a[x], 0};
nd[++tot] = (Node) {1, x, 0, a[x] = k, 0};
} else {
int l = read(), r = read(), k = read();
nd[++tot] = (Node) {3, l, r, k, ++cntq};
}
}
solve(-inf, inf, 1, tot);
for (int i = 1; i <= cntq; ++i)
printf("%d\n", ans[i]);
return 0;
}
维护 \(n\) 个可重集,初始均为空。\(m\) 次操作:
- 将 \(k\) 加入到编号在 \([l, r]\) 内的集合中。
- 查询编号在 \([l, r]\) 内的集合的并集的第 \(k\) 大值。
注意可重集的并是不去除重复元素的。
\(n, m \leq 5 \times 10^4\)
用线段树维护区间加、区间和即可。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 5e4 + 7;
struct Node {
ll k;
int l, r, id;
} nd[N], ndl[N], ndr[N];
int ans[N];
int n, m, cntq;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace SMT {
ll s[N << 2];
int tag[N << 2];
inline int ls(int x) {
return x << 1;
}
inline int rs(int x) {
return x << 1 | 1;
}
inline void pushup(int x) {
s[x] = s[ls(x)] + s[rs(x)];
}
inline void spread(int x, int l, int r, int k) {
s[x] += 1ll * (r - l + 1) * k, tag[x] += k;
}
inline void pushdown(int x, int l, int r) {
if (tag[x]) {
int mid = (l + r) >> 1;
spread(ls(x), l, mid, tag[x]);
spread(rs(x), mid + 1, r, tag[x]);
tag[x] = 0;
}
}
void update(int x, int nl, int nr, int l, int r, int k) {
if (l <= nl && nr <= r) {
spread(x, nl, nr, k);
return;
}
pushdown(x, nl, nr);
int mid = (nl + nr) >> 1;
if (l <= mid)
update(ls(x), nl, mid, l, r, k);
if (r > mid)
update(rs(x), mid + 1, nr, l, r, k);
pushup(x);
}
ll query(int x, int nl, int nr, int l, int r) {
if (l <= nl && nr <= r)
return s[x];
pushdown(x, nl, nr);
int mid = (nl + nr) >> 1;
if (r <= mid)
return query(ls(x), nl, mid, l, r);
else if (l > mid)
return query(rs(x), mid + 1, nr, l, r);
else
return query(ls(x), nl, mid, l, r) + query(rs(x), mid + 1, nr, l, r);
}
} // namespace SMT
void solve(int l, int r, int L, int R) {
if (L > R)
return;
if (l == r) {
for (int i = L; i <= R; ++i)
if (nd[i].id)
ans[nd[i].id] = l;
return;
}
int mid = (l + r) >> 1, lp = 0, rp = 0;
for (int i = L; i <= R; ++i)
if (nd[i].id) {
ll res = SMT::query(1, 1, n, nd[i].l, nd[i].r);
if (res < nd[i].k)
nd[i].k -= res, ndl[lp++] = nd[i];
else
ndr[rp++] = nd[i];
} else {
if (nd[i].k <= mid)
ndl[lp++] = nd[i];
else
SMT::update(1, 1, n, nd[i].l, nd[i].r, 1), ndr[rp++] = nd[i];
}
for (int i = L; i <= R; ++i)
if (!nd[i].id && nd[i].k > mid)
SMT::update(1, 1, n, nd[i].l, nd[i].r, -1);
memcpy(nd + L, ndl, sizeof(Node) * lp);
memcpy(nd + L + lp, ndr, sizeof(Node) * rp);
solve(l, mid, L, L + lp - 1), solve(mid + 1, r, L + lp, R);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= m; ++i) {
nd[i].id = (read() == 1 ? 0 : ++cntq);
nd[i].l = read(), nd[i].r = read(), nd[i].k = read<ll>();
}
solve(-n, n, 1, m);
for (int i = 1; i <= cntq; ++i)
printf("%d\n", ans[i]);
return 0;
}
给出一个 \(n \times n\) 的矩阵,\(q\) 次询问一个子矩形的 \(k\) 小值。
\(n \leq 500, q \leq 6 \times 10^4\)
用二维树状数组维护答案与 \(mid\) 的关系即可,时间复杂度 \(O((n^2 + q) \log^2 n \log V)\) 。
#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e2 + 7, M = 6e4 + 7;
struct Node {
int x, y, xx, yy, k, id;
} nd[N * N + M], ndl[N * N + M], ndr[N * N + M];
int ans[M];
int n, m, tot;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace BIT {
int c[N][N];
inline void update(int x, int y, int k) {
for (int i = x; i <= n; i += i & -i)
for (int j = y; j <= n; j += j & -j)
c[i][j] += k;
}
inline int ask(int x, int y) {
int res = 0;
for (int i = x; i; i -= i & -i)
for (int j = y; j; j -= j & -j)
res += c[i][j];
return res;
}
inline int query(int x, int y, int xx, int yy) {
return ask(xx, yy) - ask(x - 1, yy) - ask(xx, y - 1) + ask(x - 1, y - 1);
}
} // namespace BIT
void solve(int l, int r, int L, int R) {
if (L > R)
return;
if (l == r) {
for (int i = L; i <= R; ++i)
if (nd[i].id)
ans[nd[i].id] = l;
return;
}
int mid = (l + r) >> 1, ql = 0, qr = 0;
for (int i = L; i <= R; ++i)
if (nd[i].id) {
int res = BIT::query(nd[i].x, nd[i].y, nd[i].xx, nd[i].yy);
if (nd[i].k <= res)
ndl[ql++] = nd[i];
else
nd[i].k -= res, ndr[qr++] = nd[i];
} else {
if (nd[i].k <= mid)
BIT::update(nd[i].x, nd[i].y, 1), ndl[ql++] = nd[i];
else
ndr[qr++] = nd[i];
}
for (int i = L; i <= R; ++i)
if (!nd[i].id && nd[i].k <= mid)
BIT::update(nd[i].x, nd[i].y, -1);
memcpy(nd + L, ndl, sizeof(Node) * ql);
memcpy(nd + L + ql, ndr, sizeof(Node) * qr);
solve(l, mid, L, L + ql - 1), solve(mid + 1, r, L + ql, R);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j)
nd[++tot].x = i, nd[tot].y = j, nd[tot].k = read();
for (int i = 1; i <= m; ++i) {
nd[++tot].x = read(), nd[tot].y = read();
nd[tot].xx = read(), nd[tot].yy = read();
nd[tot].k = read(), nd[tot].id = i;
}
solve(-inf, inf, 1, tot);
for (int i = 1; i <= m; ++i)
printf("%d\n", ans[i]);
return 0;
}
求解单调序列
给定一个序列,每次操作可以把某个数 \(+1\) 或 \(-1\) 。要求把序列变成非降数列,求最小操作次数。
\(n \leq 5 \times 10^5\)
事实上在满足操作次数最小化的前提下,一定存在一种方案使得最后序列中的每个数都是序列修改前存在的,这个结论可以使用数学归纳法证明。
由于要求最终的序列单调不降,可以使用整体二分。每轮整体二分判定最终序列区间 \([L, R]\) 的值域,此时答案的值域为 \([l, r]\) 。每轮二分开始时默认将所有数划分到 \([mid + 1, r]\) ,即划分到 \([l, mid]\) 的数设为 \(0\) 个。初始代价设为将序列区间 \([L, R]\) 全部置为 \(mid + 1\) 的操作次数。
依次枚举 \([L, R]\) 中的数 \(i\) 并且计算将 \([L, i]\) 置为 \(mid\) 、将 \([i + 1, R]\) 置为 \(mid + 1\) 的操作次数之和,如果优于之前的操作次数则更新最少操作次数和要划分到 \([l, mid]\) 的数的个数。
正确性:划分时已经保证了最终序列的单调性不被破坏,同时因为每次都取最小操作次数,最终被划分至左区间的数取 \(mid\) 一定比取 \(mid + 1\) 更优。
时间复杂度 \(O(n \log V)\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 5e5 + 7;
int a[N], b[N];
int n;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
void solve(int l, int r, int L, int R) {
if (L > R || l == r)
return;
int mid = (l + r) >> 1;
ll sum = 0;
for (int i = L; i <= R; ++i)
sum += abs(a[i] - mid - 1);
ll mn = sum;
int mnpos = L - 1;
for (int i = L; i <= R; ++i) {
sum += abs(a[i] - mid) - abs(a[i] - mid - 1);
if (sum < mn)
mn = sum, mnpos = i;
}
fill(b + L, b + mnpos + 1, mid), fill(b + mnpos + 1, b + R + 1, mid + 1);
solve(l, mid, L, mnpos), solve(mid + 1, r, mnpos + 1, R);
}
signed main() {
n = read();
for (int i = 1; i <= n; ++i)
a[i] = read();
solve(-inf, inf, 1, n);
ll ans = 0;
for (int i = 1; i <= n; ++i)
ans += abs(a[i] - b[i]);
printf("%lld", ans);
return 0;
}
P4331 [BalticOI 2004] Sequence 数字序列
给定一个整数序列 \(a_{1 \sim n}\),求出一个严格递增序列 \(b_{1 \sim n}\),使得 \(\sum_{i = 1}^n |a_i - b_i|\) 最小。
\(n \leq 10^6\)
和上题类似,但是需要输出方案。上题解决了非严格递增的情况,转到严格递增有一个 trick 就是先将 \(a_i\) 赋值为 \(a_i - i\) ,最后将答案 \(b_i\) 赋值为 \(b_i + i\) 。
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 1e6 + 7;
int a[N], b[N];
int n;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
void solve(ll l, ll r, int L, int R) {
if (L > R || l == r)
return;
int mid = (l + r) >> 1;
ll sum = 0;
for (int i = L; i <= R; ++i)
sum += abs(a[i] - mid - 1);
ll mn = sum;
int mnpos = L - 1;
for (int i = L; i <= R; ++i) {
sum += abs(a[i] - mid) - abs(a[i] - mid - 1);
if (sum < mn)
mn = sum, mnpos = i;
}
fill(b + L, b + mnpos + 1, mid), fill(b + mnpos + 1, b + R + 1, mid + 1);
solve(l, mid, L, mnpos), solve(mid + 1, r, mnpos + 1, R);
}
signed main() {
n = read();
for (int i = 1; i <= n; ++i)
a[i] = read() - i;
solve(0, 1ll << 31, 1, n);
ll ans = 0;
for (int i = 1; i <= n; ++i)
ans += abs(a[i] - b[i]);
printf("%lld\n", ans);
for (int i = 1; i <= n; ++i)
printf("%d ", b[i] + i);
return 0;
}
维护不可加贡献
策略为每次处理 \(solve(l, r, L, R)\) 时,先执行 \([l, mid]\) 的修改,将 \([L, R]\) 分为两部分后先不清空递归右半部分,再撤销 \([l, mid]\) 的修改递归左半部分。
给定 \(n\) 个点的无向图,依次加入 \(m\) 条无向带权边,每次加入后询问是否存在一个边集,满足每个点的度数均为奇数,若存在则还需要最小化边集中的最大边权。
\(n \leq 10^5, m \leq 3 \times 10^5\)
首先有一个结论:存在合法边集当且仅当所有连通块大小均为偶数。
必要性:连通块大小为奇数时若存在方案,则保留合法边集后此连通块度数之和为奇数,矛盾。
充分性:每个联通块内仅保留一棵生成树,然后从叶子开始,一个点与其父亲的连边保留当且仅当这个点与其所有儿子的连边数为偶数,那么就可以构造出来了。
先考虑无修改的情况:连通块大小均为偶数时,再添加一些边后依然满足条件,所以按边权从小到大排序后,有用的边一定是一个前缀,并且具有单调性。而此题带修改、多询问,且答案单调不增,自然想到整体二分。
用可撤销并查集维护连通块,按照上述的方法分治即可。
令 solve(l, r, L, R)
表示 \([L, R]\) 的答案 \(\in [l, r]\) 。注意到每次分治时编号 \(< L\) 且权值 \(\leq l\) 的边一定被考虑,故考虑保证每次分治时这些边已经加入并查集。
令 \(mid = \dfrac{l + r}{2}\) ,先加入权值 \(\leq mid\) 且编号 \(< L\) 的必须边,然后依次加入权值 \(\leq mid\) 且未考虑的边,记第一个合法的位置为 \(cur\) ,则 \(ans_{cur} \leq mid, ans_{cur - 1} > mid\) ,递归分治即可。
时间复杂度 \(O(m \log m \log n)\) 。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, M = 3e5 + 7;
struct Edge {
int u, v, w;
} e[M], g[M];
struct DSU {
int fa[N], siz[N], sta[N];
int top, odd;
inline void prework(int n) {
iota(fa + 1, fa + 1 + n, 1);
fill(siz + 1, siz + 1 + n, 1);
odd = n;
}
inline int find(int x) {
while (x != fa[x])
x = fa[x];
return x;
}
inline void merge(int x, int y) {
x = find(x), y = find(y);
if (x == y)
return;
if (siz[x] < siz[y])
swap(x, y);
sta[++top] = y;
if ((siz[x] & 1) && (siz[y] & 1))
odd -= 2;
fa[y] = x, siz[x] += siz[y];
}
inline void restore(int k) {
while (top > k) {
int y = sta[top--], x = fa[y];
fa[y] = y, siz[x] -= siz[y];
if ((siz[x] & 1) && (siz[y] & 1))
odd += 2;
}
}
} dsu;
int id[M], ans[M];
int n, m;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
void solve(int l, int r, int L, int R) {
if (L > R || l > r)
return;
if (l == r) {
fill(ans + L, ans + R + 1, e[id[l]].w);
return;
}
int mid = (l + r) >> 1, oritop = dsu.top;
for (int i = l; i <= mid; ++i)
if (id[i] < L)
dsu.merge(e[id[i]].u, e[id[i]].v);
int cur = R + 1, pretop = dsu.top;
for (int i = L; i <= R; ++i) {
if (e[i].w <= e[id[mid]].w)
dsu.merge(e[i].u, e[i].v);
if (!dsu.odd) {
cur = i;
break;
}
}
dsu.restore(pretop);
solve(mid + 1, r, L, cur - 1);
dsu.restore(oritop);
for (int i = L; i < cur; ++i)
if (e[i].w <= e[id[l]].w)
dsu.merge(e[i].u, e[i].v);
solve(l, mid, cur, R);
dsu.restore(oritop);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= m; ++i)
e[i].u = read(), e[i].v = read(), e[i].w = read();
iota(id + 1, id + 1 + m, 1);
sort(id + 1, id + 1 + m, [](const int &a, const int &b) { return e[a].w < e[b].w; });
e[m + 1].w = -1, id[m + 1] = m + 1, dsu.prework(n);
solve(1, m + 1, 1, m);
for (int i = 1; i <= m; ++i)
printf("%d\n", ans[i]);
return 0;
}
求解一般性问题
给定一棵树,\(m\) 次操作:
- 向路径集合中加入路径 \(x \to y\) ,权值为 \(k\) 。
- 向路径集合中删除第 \(x\) 条路径。
- 求路径集合中所有不经过 \(u\) 的路径的权值最大值。
\(n \leq 10^5, m \leq 2 \times 10^5\)
考虑整体二分,如果某个询问点被所有大于 \(mid\) 的路径所经过,那么答案 \(\leq mid\) ,否则答案 \(> mid\) 。查询经过一个点的路径条数用树上差分即可。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7, M = 2e5 + 7;
struct Graph {
vector<int> e[N];
inline void insert(int u, int v) {
e[u].emplace_back(v);
}
} G;
struct Node {
int op, x, y, k, id;
} nd[M], ndl[M], ndr[M];
int fa[N], dep[N], siz[N], son[N], top[N], dfn[N], ans[M];
int n, m, cntq, dfstime;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
void dfs1(int u, int f) {
fa[u] = f, dep[u] = dep[f] + 1, siz[u] = 1;
for (int v : G.e[u]) {
if (v == f)
continue;
dfs1(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
}
}
void dfs2(int u, int topf) {
top[u] = topf, dfn[u] = ++dfstime;
if (son[u])
dfs2(son[u], topf);
for (int v : G.e[u])
if (v != fa[u] && v != son[u])
dfs2(v, v);
}
inline int LCA(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]])
swap(x, y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
namespace BIT {
int c[N];
inline void update(int x, int k) {
for (; x <= n; x += x & -x)
c[x] += k;
}
inline int ask(int x) {
int res = 0;
for (; x; x -= x & -x)
res += c[x];
return res;
}
inline int query(int l, int r) {
return ask(r) - ask(l - 1);
}
} // namespace BIT
inline void update(int x, int y, int k) {
BIT::update(dfn[x], k), BIT::update(dfn[y], k);
int lca = LCA(x, y);
BIT::update(dfn[lca], -k);
if (fa[lca])
BIT::update(dfn[fa[lca]], -k);
}
void solve(int l, int r, int L, int R) {
if (L > R)
return;
if (l == r) {
for (int i = L; i <= R; ++i)
if (nd[i].op == 2)
ans[nd[i].id] = l;
return;
}
int mid = (l + r) >> 1, ql = 0, qr = 0;
for (int i = L, sum = 0; i <= R; ++i)
if (nd[i].op == 2) {
if (BIT::query(dfn[nd[i].x], dfn[nd[i].x] + siz[nd[i].x] - 1) == sum)
ndl[ql++] = nd[i];
else
ndr[qr++] = nd[i];
} else {
if (nd[i].k <= mid)
ndl[ql++] = nd[i];
else {
ndr[qr++] = nd[i], sum += nd[i].op;
update(nd[i].x, nd[i].y, nd[i].op);
}
}
for (int i = L; i <= R; ++i)
if (nd[i].op != 2 && nd[i].k > mid)
update(nd[i].x, nd[i].y, -nd[i].op);
memcpy(nd + L, ndl, sizeof(Node) * ql);
memcpy(nd + L + ql, ndr, sizeof(Node) * qr);
solve(l, mid, L, L + ql - 1), solve(mid + 1, r, L + ql, R);
}
signed main() {
n = read(), m = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
G.insert(u, v), G.insert(v, u);
}
dfs1(1, 0), dfs2(1, 1);
for (int i = 1; i <= m; ++i) {
nd[i].op = read();
if (!nd[i].op)
nd[i].op = 1, nd[i].x = read(), nd[i].y = read(), nd[i].k = read();
else if (nd[i].op == 1)
nd[i] = nd[read()], nd[i].op = -1;
else
nd[i].x = read(), nd[i].id = ++cntq;
}
solve(-1, 1e9, 1, m);
for (int i = 1; i <= cntq; ++i)
printf("%d\n", ans[i]);
return 0;
}
给出一个分为 \(m\) 段的环形序列与 \(n\) 个国家,第 \(i\) 段属于国家 \(o_i\) 。有 \(k\) 次事件,每次给环形序列上的一个区间加上一个正整数。每个国家有一个期望 \(p_i\) ,求出每个国家在序列上所有位置的值的和到达 \(p_i\) 的最早时间,或报告无法达到。
\(n, m, k \leq 3 \times 10^5\)
思维难度不算太高的整体二分。
#include <bits/stdc++.h>
typedef unsigned long long ull;
using namespace std;
const int N = 3e5 + 7;
struct Update {
int l, r, k, id;
} upd[N << 1];
vector<int> collect[N];
int o[N], p[N], id[N], ans[N], idl[N], idr[N];
int n, m, q, tot;
template <class T = int>
inline T read() {
char c = getchar();
bool sign = (c == '-');
while (c < '0' || c > '9')
c = getchar(), sign |= (c == '-');
T x = 0;
while ('0' <= c && c <= '9')
x = (x << 1) + (x << 3) + (c & 15), c = getchar();
return sign ? (~x + 1) : x;
}
namespace BIT {
ull c[N];
inline void modify(int x, int k) {
for (; x <= m; x += x & -x)
c[x] += k;
}
inline void update(int l, int r, int k) {
modify(l, k), modify(r + 1, -k);
}
inline ull query(int x) {
ull res = 0;
for (; x; x -= x & -x)
res += c[x];
return res;
}
} // namespace BIT
void solve(int l, int r, int L, int R) {
if (L > R)
return;
if (l == r) {
for (int i = L; i <= R; ++i)
ans[id[i]] = upd[l].id;
return;
}
int mid = (l + r) >> 1;
for (int i = l; i <= mid; ++i)
BIT::update(upd[i].l, upd[i].r, upd[i].k);
int ql = 0, qr = 0;
for (int i = L; i <= R; ++i) {
ull res = 0;
for (int x : collect[id[i]])
res += BIT::query(x);
if (p[id[i]] <= res)
idl[ql++] = id[i];
else
p[id[i]] -= res, idr[qr++] = id[i];
}
for (int i = l; i <= mid; ++i)
BIT::update(upd[i].l, upd[i].r, -upd[i].k);
memcpy(id + L, idl, sizeof(int) * ql);
memcpy(id + L + ql, idr, sizeof(int) * qr);
solve(l, mid, L, L + ql - 1), solve(mid + 1, r, L + ql, R);
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= m; ++i)
collect[read()].emplace_back(i);
for (int i = 1; i <= n; ++i)
p[i] = read();
q = read();
for (int i = 1; i <= q; ++i) {
int l = read(), r = read(), k = read();
if (l <= r)
upd[++tot] = (Update) {l, r, k, i};
else {
upd[++tot] = (Update) {l, m, k, i};
upd[++tot] = (Update) {1, r, k, i};
}
}
upd[++tot].id = q + 1;
iota(id + 1, id + 1 + n, 1);
solve(1, tot, 1, n);
for (int i = 1; i <= n; ++i)
if (ans[i] == q + 1)
puts("NIE");
else
printf("%d\n", ans[i]);
return 0;
}