CF1034D Intervals of Intervals
CF1034D Intervals of Intervals
题目大意
给定数轴上的 \(n\) 条线段 \([a_i, b_i]\)(\(1\leq i\leq n\),\(1\leq a_i < b_i\leq 10^9\))。
定义区间 \([l, r]\)(\(1\leq l\leq r\leq n\))的价值为 \(l,\dots ,r\) 这些线段的并的长度。
请你选择 \(k\) 个不同的区间,最大化所选区间的价值之和,并输出这个最大值。
数据范围:\(1\leq n\leq 3\times 10^5\),\(1\leq k\leq \min\{\frac{n(n + 1)}{2},10^9\}\)。
本题题解
考虑一个子问题:有 \(n\) 条线段,\(q\) 次询问,每次问一个区间 \([l, r]\)(\(1\leq l\leq r\leq n\))里的线段的并的长度。
可以离线,按编号从 \(1\) 到 \(n\) 考虑每条线段,假设当前线段 \(i\) 是询问的右端点,我们维护每个左端点的答案。
与此同时,对数轴上所有位置 \(p\),维护最后一个覆盖到它的线段编号 \(\mathrm{lst}(p)\),特别地,如果位置 \(p\) 没有被覆盖过,认为 \(\mathrm{lst}(p) = 0\)。那么一次询问 \([l,i]\) 的答案就是 \(\mathrm{lst}\) 值大于等于 \(l\) 的位置有多少个。
加入线段 \(i\),相当于把 \([a_i, b_i]\) 这段区间里所有位置的 \(\mathrm{lst}\) 值改为 \(i\)。假设一个位置原本的 \(\mathrm{lst}\) 值为 \(i'\),那么 \([i' + 1, i]\) 里的左端点,答案都要增加 \(1\)。
把相邻、且 \(\mathrm{lst}\) 值相同的位置,缩成一条线段。那么只需要用 \(\texttt{std::set}\) 维护若干条线段。每个 \(i\) 只会带来新增 \(\mathcal{O}(1)\) 条线段(虽然可能它会一次删去很多线段,但这个复杂度是均摊的,所以暴力删就行)。删除一个原有线段,对答案数组(每个左端点的答案)带来的修改是区间加,用线段树维护答案数组即可。
总时间复杂度 \(\mathcal{O}(n\log n)\)。另外提一句,把线段树的修改可持久化一下,该问题也可以在线做。
回到本题,我们显然是选价值最大的 \(k\) 个区间。考虑二分价值第 \(k\) 大的区间的价值。设当前二分值为 \(\mathrm{mid}\),问题转化为求:价值大于等于 \(\mathrm{mid}\) 的区间数量,根据数量 \(<\) 或 \(\geq k\),调整二分的上下界。
设区间 \([l, r]\) 的价值为 \(v(l, r)\)。那么显然有:\(v(l, r)\geq v(l + 1, r)\),\(v(l, r)\geq v(l, r - 1)\) (\(l < r\))。
我们仍然从 \(1\) 到 \(n\) 枚举右端点。对于同一个右端点 \(i\),设每个左端点 \(l\) 的答案为 \(f_i(l)\),即 \(f_i(l) = v(l, i)\)。那么有:\(f_i(1)\geq f_i(2)\geq \dots\geq f_i(i)\)。
记最后一个满足 \(f_i(l)\geq \mathrm{mid}\) 的 \(l\) 为 \(L(i)\)。特别地,如果 \(f_i(1) < \mathrm{mid}\),\(L_i(i) = 0\)。我们要求的就是 \(\sum_{i = 1}^{n}L(i)\)。如果用线段树维护 \(f\) 数组(支持区间加),那么 \(L(i)\) 可以在线段树上二分求出来,不过这样总时间复杂度是 \(\mathcal{O}(n\log^2 n)\) 的。还可以继续优化。
考虑到 \(L(1)\leq L(2)\leq \dots\leq L(n)\),所以可以在枚举 \(i\) 的同时,用 \(\text{two pointers}\) 的方法维护 \(L(i)\)。在二分之前,先预处理出要做的区间加法(也就是把 \(\texttt{set}\) 的 \(\log\) 放在二分外面),然后用差分来实现区间加。
时间复杂度 \(\mathcal{O}(n\log n)\)。
有一个要注意的细节是,二分出第 \(k\) 大的价值 \(v\) 后,答案并不一定是所有 \(\geq v\) 的价值之和。因为价值等于 \(v\) 的区间可能有多个,超过 \(k\) 个的部分我们不能取。
参考代码
实际提交时建议使用读入、输出优化,详见本博客公告。
// problem: CF1034D
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 3e5;
const int MAXR = 1e9;
int n, K;
pii seg[MAXN + 5];
vector<pii> ev[MAXN + 5];
#define mk3(x, y, z) mk(mk(x, y), z)
#define IT set<pair<pii, int> > :: iterator
set<pair<pii, int> > s;
void insert(int sl, int sr, int sid) {
IT it = s.lower_bound(mk3(sl, 0, 0));
if (it != s.begin()) {
--it;
if (it -> fi.se > sl) {
pair<pii, int> tmp = *it;
s.erase(it);
s.insert(mk3(tmp.fi.fi, sl, tmp.se));
s.insert(mk3(sl, tmp.fi.se, tmp.se));
}
}
it = s.lower_bound(mk3(sr, 0, 0));
if (it != s.begin()) {
--it;
if (it -> fi.se > sr) {
pair<pii, int> tmp = *it;
s.erase(it);
s.insert(mk3(tmp.fi.fi, sr, tmp.se));
s.insert(mk3(sr, tmp.fi.se, tmp.se));
}
}
IT bg = s.lower_bound(mk3(sl, 0, 0));
IT ed = bg;
if (bg != s.end() && bg -> fi.se <= sr) {
int lst = sl;
do {
int cl = ed -> fi.fi;
int cr = ed -> fi.se;
if (cl != lst)
ev[sid].push_back(mk(1, cl - lst));
ev[sid].push_back(mk(ed -> se + 1, cr - cl));
lst = cr;
++ed;
} while(ed != s.end() && ed -> fi.se <= sr);
if (lst != sr)
ev[sid].push_back(mk(1, sr - lst));
s.erase(bg, ed);
} else {
ev[sid].push_back(mk(1, sr - sl));
}
s.insert(mk3(sl, sr, sid));
// cerr << "**********" << endl;
// for (it = s.begin(); it != s.end(); ++it) {
// cerr << (it -> fi.fi) << " " << (it -> fi.se) << endl;
// }
}
#undef IT
#undef mk3
int f[MAXN + 5];
ll calc(int min_size) {
// 并的大小 >= min_size 的区间有几个
// 时间复杂度 O(n)
for (int i = 1; i <= n; ++i) {
f[i] = 0; // 清空
}
int l = 0;
ll res = 0;
for (int i = 1; i <= n; ++i) {
// 右端点: i
for (int j = 0; j < SZ(ev[i]); ++j) {
int p = ev[i][j].fi;
int v = ev[i][j].se;
ckmax(p, l);
// 对 l = p...i 的答案, 加上 v
f[p] += v;
f[i + 1] -= v;
}
// l: 使得 [l, i] 的答案 >= min_size 的最大的 l
while (l + 1 <= i && f[l + 1] + f[l] >= min_size) {
f[l + 1] += f[l];
if (l + 1 > 1)
assert(f[l + 1] <= f[l]);
++l;
}
// cerr << "** " << l << endl;
res += l;
}
return res;
}
ll calc_ans(int min_size) {
// 并的大小 >= min_size 的区间的答案之和
// 时间复杂度 O(n)
for (int i = 1; i <= n; ++i) {
f[i] = 0; // 清空
}
int l = 0;
ll sum = 0, res = 0;
for (int i = 1; i <= n; ++i) {
// 右端点: i
for (int j = 0; j < SZ(ev[i]); ++j) {
int p = ev[i][j].fi;
int v = ev[i][j].se; // 对 l = p...i 的答案, 加上 v
if (p <= l) {
sum += (ll)v * (l - p + 1);
}
ckmax(p, l);
f[p] += v;
f[i + 1] -= v;
}
// l: 使得 [l, i] 的答案 >= min_size 的最大的 l
while (l + 1 <= i && f[l + 1] + f[l] >= min_size) {
f[l + 1] += f[l];
sum += f[l + 1];
++l;
}
res += sum;
}
return res;
}
int main() {
cin >> n >> K;
for (int i = 1; i <= n; ++i) {
cin >> seg[i].fi >> seg[i].se;
}
for (int i = 1; i <= n; ++i) {
insert(seg[i].fi, seg[i].se, i);
}
int l = 0, r = MAXR;
while (l < r) {
int mid = (l + r + 1) >> 1;
if (calc(mid) >= K) {
l = mid;
} else {
r = mid - 1;
}
}
// cerr << l << endl;
cout << calc_ans(l + 1) + (ll)l * (K - calc(l + 1)) << endl;
return 0;
}