【Luogu P4577】「FJOI2018」领导集团问题
Description
给出一棵大小为 \(n\) 的树,第 \(i\) 个节点的权值为 \(w_i\)。
你需要选出一个最大的节点集合,使得对于节点集合中不同的两个点 \(i, j\),若 \(i\) 为 \(j\) 的祖先节点,必须满足 \(w_i \leq w_j\)。
数据范围:\(1 \leq n \leq 2 \times 10^5\),\(1 \leq w_i \leq 10^9\)。
时空限制:\(1000 \ \mathrm{ms}/ 250 \ \mathrm{MiB}\)。
Solution
提供一种类似 [NOI2020] 命运 的整体 dp 做法。
考虑 dp。记 \(f(u, i)\) 表示:在子树 \(u\) 中选出的节点集合的 \(w\) 最小值为 \(i\) 的情况下,最大的节点集合的大小。有两种转移:
- 节点集合不包括 \(u\)(要保证 \(f(u, i)\) 至少有一个 \(f(v, i)\) 的转移)。
- 节点集合包括 \(u\)。
考虑线段树合并优化,在一棵维护子树 \(u\) 的线段树中,代表区间 \([l, r]\) 的节点需要维护在子树 \(u\) 中选出的节点集合的 \(w\) 最小值在区间 \([l, r]\) 中的情况下,最大的节点集合的大小,即 \(\max\limits_{l \leq i \leq r} f(u, i)\)。
转移 1 即为整体 dp 的重点,在线段树合并的过程中计算。具体地,在合并线段树 p, q
的过程中,设当前合并到了代表区间 [l, r]
的节点,在递归的过程中记录 \(\max\limits_{i > r} \{f(u, i)\}\) 与 \(\max\limits_{i > r} \{f(v, i)\}\) 并分别记作 maxp, maxq
。将情况分成以下五类讨论:
- 当 \(p = 0, q = 0\) 时:返回空节点 \(0\) 即可。
- 当 \(p \neq 0, q = 0\) 时:此时相当于对 \(p\) 内做一次区间加
maxq
,打上懒标记后返回 \(p\) 即可。 - 当 \(p = 0, q \neq 0\) 时:此时相当于对 \(q\) 内做一次区间加
maxp
,打上懒标记后返回 \(q\) 即可。 - 当 \(l = r\) 时:此时递归到了一个叶子节点,有 \(f'(u, l) = \max(\mathtt{maxp}, f(u, l)) + \max(\mathtt{maxq}, f(v, l))\)。
- 当 \(p \neq 0, q \neq 0\) 时:此时需要先合并 \(p, q\) 的左右儿子,再以通过左右儿子上传信息。
特别要注意的是,区间加不用也不能对空节点打标记,因为空节点不能保证 \(f(u, i)\) 至少有一个 \(f(v, i)\) 的转移。
转移 2 即为平凡的区间查询最大值,单点修改。
时间复杂度 \(\mathcal{O}(n \log w)\),空间复杂度 \(\mathcal{O}(n \log w)\)。
#include <cstdio>
#include <cstring>
#include <algorithm>
template <class T>
inline void read(T &x) {
static char s;
while (s = getchar(), s < '0' || s > '9');
x = s - '0';
while (s = getchar(), s >= '0' && s <= '9') x = x * 10 + s - '0';
}
const int N = 200100;
const int sup = 1e9;
int n;
int a[N];
int tot, head[N], ver[N], Next[N];
void add_edge(int u, int v) {
ver[++ tot] = v; Next[tot] = head[u]; head[u] = tot;
}
namespace SGT {
const int pond = 10001000;
int nClock, root[N];
struct node {
int lc, rc;
int max;
int add;
void mk_add(int x) {
max += x;
add += x;
}
} t[pond];
void spread(int p) {
if (t[p].add) {
if (t[p].lc) t[t[p].lc].mk_add(t[p].add);
if (t[p].rc) t[t[p].rc].mk_add(t[p].add);
t[p].add = 0;
}
}
void insert(int &p, int l, int r, int x, int val) {
if (!p) p = ++ nClock;
t[p].max = std::max(t[p].max, val);
if (l == r) return;
spread(p);
int mid = (l + r) >> 1;
if (x <= mid)
insert(t[p].lc, l, mid, x, val);
else
insert(t[p].rc, mid + 1, r, x, val);
}
int ask(int p, int l, int r, int s, int e) {
if (s <= l && r <= e) return t[p].max;
spread(p);
int mid = (l + r) >> 1;
int cur = 0;
if (s <= mid)
cur = std::max(cur, ask(t[p].lc, l, mid, s, e));
if (mid < e)
cur = std::max(cur, ask(t[p].rc, mid + 1, r, s, e));
return cur;
}
int merge(int p, int q, int l, int r, int maxp, int maxq) {
if (!p && !q) return 0;
if (p && !q) return t[p].mk_add(maxq), p;
if (!p && q) return t[q].mk_add(maxp), q;
if (l == r)
return t[p].max = std::max(t[p].max, maxp) + std::max(t[q].max, maxq), p;
spread(p), spread(q);
int mid = (l + r) >> 1;
t[p].lc = merge(t[p].lc, t[q].lc, l, mid,
std::max(t[t[p].rc].max, maxp), std::max(t[t[q].rc].max, maxq));
t[p].rc = merge(t[p].rc, t[q].rc, mid + 1, r, maxp, maxq);
t[p].max = std::max(t[t[p].lc].max, t[t[p].rc].max);
return p;
}
}
void dp(int u) {
int val = 1;
for (int i = head[u]; i; i = Next[i]) {
int v = ver[i];
dp(v);
val += SGT::ask(SGT::root[v], 1, sup, a[u], sup);
SGT::root[u] = SGT::merge(SGT::root[u], SGT::root[v], 1, sup, 0, 0);
}
SGT::insert(SGT::root[u], 1, sup, a[u], val);
}
int main() {
read(n);
for (int i = 1; i <= n; i ++) read(a[i]);
for (int i = 2, fa; i <= n; i ++) {
read(fa);
add_edge(fa, i);
}
dp(1);
printf("%d\n", SGT::t[SGT::root[1]].max);
return 0;
}