[题解] Luogu2487 [SDOI2011]拦截导弹
题目大意
有 \(n\) 颗导弹按顺序拦截,每个导弹有高度 \(h\) 与速度 \(v\) ,要求下一颗拦截的导弹比当前速度慢且高度低,问最多可以拦截多少导弹、在所有拦截最多方案中每颗导弹被拦截的概率。
思路
首先对于 \(h\) 和 \(v\) 进行离散化,方便后面操作。
这是一个带两个参数的 LIS 问题,可以想到用 DP 的思路,但是暴力 DP 的时间复杂度为 \(O(n^2)\)。
可以将问题转化为三维偏序,然后经典套路 CDQ 分治来优化解决。
但是由于 DP 的特殊顺序,必须处理完左区间之后及时地与右区间合并,最后处理右区间。然后第一问就可以解决了。
将第二问可以拆分成两个子问题:以当前导弹结尾的方案数、以当前导弹开始的方案数。就可以将两个方案数相乘,算出包含该颗导弹的方案数。
Code
#include <cstdio>
#include <algorithm>
using namespace std;
#define int __int128
//计算过程中会爆long long,用int128或者double存
void Read(int &n) {
n = 0; int op = 1; char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') op = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
n = (n << 1) + (n << 3) + (c ^ 48);
c = getchar();
}
n *= op;
}
void Write(int n) {
if(n < 0) putchar('-'), n = -n;
if(n >= 10) Write(n / 10);
putchar(n % 10 + 48);
}
const int MAXN = 2e5 + 5;
struct Seg {//线段树维护区间最大值
int l, r, maxn, cnt;//cnt为该区间的最大值有多少个,便于统计方案数
#define ls (pos << 1)
#define rs (pos << 1 | 1)
} t[MAXN << 2];
void Push_Up(int pos) {
t[pos].maxn = max(t[ls].maxn, t[rs].maxn);
t[pos].cnt = 0;
if (t[ls].maxn == t[pos].maxn) t[pos].cnt += t[ls].cnt;
if (t[rs].maxn == t[pos].maxn) t[pos].cnt += t[rs].cnt;
}
void Clear(int pos) {
if (!t[pos].maxn) return;
t[pos].maxn = 0;
if (t[pos].l == t[pos].r) return;
Clear(ls), Clear(rs);
}
void Build(int pos, int l, int r) {
t[pos].l = l, t[pos].r = r;
if (l == r) return;
int mid = (l + r) >> 1;
Build(ls, l, mid);
Build(rs, mid + 1, r);
}
void Update(int pos, int x, int y, int d) {
if (t[pos].l == t[pos].r) {
if (t[pos].maxn < y) t[pos].maxn = y, t[pos].cnt = d;
else if (t[pos].maxn == y) t[pos].cnt += d;
return;
}
if (x <= t[ls].r) Update(ls, x, y, d);
else Update(rs, x, y, d);
Push_Up(pos);
}
void Query(int pos, int l, int r, int &res1, int &res2) {
if (l <= t[pos].l && t[pos].r <= r) {
res1 = t[pos].maxn, res2 = t[pos].cnt;
return;
}
int maxn1 = 0, maxn2 = 0, cnt1 = 0, cnt2 = 0;
if (l <= t[ls].r) Query(ls, l, r, maxn1, cnt1);
if (r >= t[rs].l) Query(rs, l, r, maxn2, cnt2);
res1 = max(maxn1, maxn2), res2 = 0;
if (res1 == maxn1) res2 += cnt1;
if (res1 == maxn2) res2 += cnt2;
}
struct Node { int h, v, t; } a[MAXN];
bool cmpt(Node x, Node y) {
return x.t < y.t;
}
bool cmph(Node x, Node y) {
return x.h != y.h ? x.h > y.h : x.t < y.t;
}
int lsh[2][MAXN], tt[2];
int f1[MAXN], g1[MAXN], f2[MAXN], g2[MAXN], ans1, ans2, sum, n;
//f1为以i开头的最长不上升子序列,g1为方案数
//f2为以i结尾的最长不上升子序列,g2为方案数
void Solve1(int l, int r) {
if (l == r) return;
int mid = (l + r) >> 1;
sort(a + l, a + 1 + r, cmpt);
Solve1(l, mid);//先解决左区间
sort(a + l, a + 1 + mid, cmph);
sort(a + mid + 1, a + 1 + r, cmph);
Clear(1);
int i = l, j = mid + 1;
while (i <= mid && j <= r) {//合并左右区间
if (a[i].h >= a[j].h) {//第二维归并处理
Update(1, a[i].v, f1[a[i].t], g1[a[i].t]);//修改区间最大值便于DP
//第三维线段树处理
i++;
}
else {
int maxn = 0, cnt = 0;
Query(1, a[j].v, n, maxn, cnt);
if (maxn) {
if (f1[a[j].t] < maxn + 1) f1[a[j].t] = maxn + 1, g1[a[j].t] = cnt;//状态转移+统计方案数
else if (f1[a[j].t] == maxn + 1) g1[a[j].t] += cnt;
}
j++;
}
}
while (j <= r) {
int maxn = 0, cnt = 0;
Query(1, a[j].v, n, maxn, cnt);
if (maxn) {
if (f1[a[j].t] < maxn + 1) f1[a[j].t] = maxn + 1, g1[a[j].t] = cnt;
else if (f1[a[j].t] == maxn + 1) g1[a[j].t] += cnt;
}
j++;
}
Solve1(mid + 1, r);//最后解决右区间
}
void Solve2(int l, int r) {
if (l == r) return;
int mid = (l + r) >> 1;
sort(a + l, a + 1 + r, cmpt);
Solve2(l, mid);
sort(a + l, a + 1 + mid, cmph);
sort(a + mid + 1, a + 1 + r, cmph);
Clear(1);
int i = l, j = mid + 1;
while (i <= mid && j <= r) {
if (a[i].h >= a[j].h) {
Update(1, a[i].v, f2[a[i].t], g2[a[i].t]);
i++;
}
else {
int maxn = 0, cnt = 0;
Query(1, a[j].v, n, maxn, cnt);
if (maxn) {
if (f2[a[j].t] < maxn + 1) f2[a[j].t] = maxn + 1, g2[a[j].t] = cnt;
else if (f2[a[j].t] == maxn + 1) g2[a[j].t] += cnt;
}
j++;
}
}
while (j <= r) {
int maxn = 0, cnt = 0;
Query(1, a[j].v, n, maxn, cnt);
if (maxn) {
if (f2[a[j].t] < maxn + 1) f2[a[j].t] = maxn + 1, g2[a[j].t] = cnt;
else if (f2[a[j].t] == maxn + 1) g2[a[j].t] += cnt;
}
j++;
}
Solve2(mid + 1, r);
}
int Get(int x, int h) {
return lower_bound(lsh[h] + 1, lsh[h] + 1 + tt[h], x) - lsh[h];
}
signed main() {
Read(n); Build(1, 1, n);
for (int i = 1; i <= n; i++) {
Read(a[i].h); Read(a[i].v);
a[i].t = i;
lsh[0][++tt[0]] = a[i].h;
lsh[1][++tt[1]] = a[i].v;
}
sort(lsh[0] + 1, lsh[0] + 1 + tt[0]);
sort(lsh[1] + 1, lsh[1] + 1 + tt[1]);
tt[0] = unique(lsh[0] + 1, lsh[0] + 1 + tt[0]) - lsh[0] - 1;
tt[1] = unique(lsh[1] + 1, lsh[1] + 1 + tt[1]) - lsh[1] - 1;//离散化
for (int i = 1; i <= n; i++) a[i].h = Get(a[i].h, 0), a[i].v = Get(a[i].v, 1);
for (int i = 1; i <= n; i++) f1[i] = f2[i] = g1[i] = g2[i] = 1;//只有自己长度为1
Solve1(1, n);
for (int i = 1; i <= n; i++) ans1 = max(ans1, f1[i]);
for (int i = 1; i <= n; i++) if (f1[i] == ans1) sum += g1[i];
Write(ans1);
printf("\n");
for (int i = 1; i <= n; i++) {//倒着跑回去,得到第二个此问题的答案
a[i].t = n - a[i].t + 1;
a[i].h = n - a[i].h + 1;
a[i].v = n - a[i].v + 1;
}
sort(a + 1, a + 1 + n, cmpt);
Solve2(1, n);
for (int i = 1; i <= n; i++) {
if (f1[i] + f2[n - i + 1] - 1 != ans1) printf("0.00000 ");
else printf("%.5lf ", (double) (1.0 * g1[i] * g2[n - i + 1] / sum));
}
return 0;
}
```