[FJOI2018]领导集团问题
首先本题贪心不是很好做,可以考虑 \(dp\)。
然后我们有了一个很直接的想法,令 \(dp_{i, j}\) 表示以 \(i\) 号点为根当前选择的节点中权值最小的权值为 \(j\) 的最大成员数,可以发现这样做是 \(O(n ^ 3)\) 的。可以发现这个 \(dp\) 有很多转移是相同的,那么我们这样做是非常浪费的,为了能够快速转移,我们可以改变一下状态,令 \(dp_{i, j}\) 为以 \(i\) 为根的子树内选择的所有点权值都不小于 \(j\) 的最大成员数,可以发现这个 \(dp\) 的值从后往前是不断递增的,那么对于每颗子树的转移我们就直接有 \(dp_{u, i} = \sum dp_{v, i}\),如果在这个集合内选择 \(u\) 则还有转移 \(dp_{u, i} = \max\{dp_{u, i}, dp_{u, w_u} + 1\}(i \le w_u)\),这样就可以做到 \(O(n ^ 2)\) 了。
下面可以考虑优化这个 \(dp\),可以发现前面的那个转移方程就是将子树内每个位置上的值简单相加,这很类似线段树合并的流程,我们可以从线段树合并的角度考虑。那么对于第二个转移方程,因为对于每个 \(dp_i\) 中的每个值 \(dp_{i, j}\) 是从后往前不断递增的,那么相当于我们需要将当前节点线段树一段区间加上 \(1\) 即可。于是我们可以直接标记永久化解决,但实际上因为我们是将一个区间加上同一个值,且只需要最后查询一次权值,可以考虑从后往前差分这个 \(dp\) 值,那么我们发现只需要在 \(w_i\) 这个位置加 \(1\),在 \(i\) 左边第一个不为 \(0\) 的位置减 \(1\) 即可。具体的这个位置可以使用线段树二分求出,一个比较聪明的实现方法是我们查询 \(1 \sim w_i\) 的权值为 \(S\),然后在线段树上二分到第一个 \(1 \sim w_i\) 为 \(S\) 的位置即可。
一些坑点:
-
动态开点的 \(num\) 容易写成其他类似 \(tot, cnt\) 之类的变量
-
调用线段树合并之类的一定要传入 \(rt\) 而不是直接传节点编号
#include<bits/stdc++.h>
using namespace std;
#define N 200000 + 5
#define M 4000000 + 5
#define ls t[p].l
#define rs t[p].r
#define mid (l + r >> 1)
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
struct edge{
int v, next;
}e[N << 1];
struct tree{
int l, r, sum;
}t[M];
int n, u, ok, tot, cnt, num, d[N], w[N], h[N], rt[N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void add(int u, int v){
e[++tot].v = v, e[tot].next = h[u], h[u] = tot;
e[++tot].v = u, e[tot].next = h[v], h[v] = tot;
}
void update(int &p, int l, int r, int x, int y, int k){
if(!p) p = ++num;
if(l >= x && r <= y){ t[p].sum += k; return;}
if(mid >= x) update(ls, l, mid, x, y, k);
if(mid < y) update(rs, mid + 1, r, x, y, k);
t[p].sum = t[ls].sum + t[rs].sum;
}
int query(int p, int l, int r, int x, int y){
if(!p) return 0;
if(l >= x && r <= y) return t[p].sum;
int ans = 0;
if(mid >= x) ans += query(ls, l, mid, x, y);
if(mid < y) ans += query(rs, mid + 1, r, x, y);
return ans;
}
int find(int p, int l, int r, int k){
if(l == r) return !t[p].sum ? 0 : l;
if(t[ls].sum >= k) return find(ls, l, mid, k);
else return find(rs, mid + 1, r, k - t[ls].sum);
}
void Merge(int &p, int k, int l, int r){
if(!p || !k){ p = p + k; return;}
if(l == r){ t[p].sum += t[k].sum; return;}
Merge(ls, t[k].l, l, mid), Merge(rs, t[k].r, mid + 1, r);
t[p].sum = t[ls].sum + t[rs].sum;
}
void dfs(int u, int fa){
Next(i, u){
int v = e[i].v; if(v == fa) continue;
dfs(v, u), Merge(rt[u], rt[v], 1, cnt);
}
int val = query(rt[u], 1, cnt, 1, w[u]);
int pos = find(rt[u], 1, cnt, val);
update(rt[u], 1, cnt, w[u], w[u], 1);
if(pos) update(rt[u], 1, cnt, pos, pos, -1);
}
int main(){
n = read();
rep(i, 1, n) w[i] = d[i] = read();
sort(d + 1, d + n + 1);
cnt = unique(d + 1, d + n + 1) - d - 1;
rep(i, 1, n) w[i] = lower_bound(d + 1, d + cnt + 1, w[i]) - d;
rep(i, 2, n) u = read(), add(u, i);
dfs(1, 0);
printf("%d", t[rt[1]].sum);
return 0;
}
实际上还有可以使用另一种方式来维护这个差分数组,那就是平衡树启发式合并,因为有 \(set\) 的存在这样非常好写。具体的我们合并儿子差分数组直接启发式合并暴力插入,由于那个需要差分单点修改的原因,为了方便我们在 \(set\) 中存入的每个元素 \(x\) 表示在 \(x\) 这个位置上的差分数组 \(+1\),那么我们的直接每次插入 \(w_i\) 这个元素,找到第一个小于 \(w_i\) 的这个元素将这个元素删除即可。最终的答案就是 \(1\) 号点 \(set\) 的大小。
#include<bits/stdc++.h>
using namespace std;
#define N 200000 + 5
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
struct edge{
int v, next;
}e[N << 1];
multiset <int> S[N];
multiset <int> :: iterator it;
int n, u, tot, cnt, h[N], d[N], w[N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void add(int u, int v){
e[++tot].v = v, e[tot].next = h[u], h[u] = tot;
e[++tot].v = u, e[tot].next = h[v], h[v] = tot;
}
void Merge(int x, int y){
if(S[x].size() < S[y].size()) swap(S[x], S[y]);
for(it = S[y].begin(); it != S[y].end(); ++it) S[x].insert(*it);
}
void dfs(int u, int fa){
Next(i, u){
int v = e[i].v; if(v == fa) continue;
dfs(v, u), Merge(u, v);
}
S[u].insert(w[u]);
it = S[u].lower_bound(w[u]);
if(it != S[u].end() && it != S[u].begin()) S[u].erase(--it);
}
int main(){
n = read();
rep(i, 1, n) w[i] = d[i] = read();
sort(d + 1, d + n + 1);
cnt = unique(d + 1, d + n + 1) - d - 1;
rep(i, 1, n) w[i] = lower_bound(d + 1, d + cnt + 1, w[i]) - d;
rep(i, 2, n) u = read(), add(u, i);
dfs(1, 0);
printf("%d", S[1].size());
return 0;
}
实际上有这样一个套路:
对于这种 \(dp\) 值实际上是一个一次函数或常数函数的 \(dp\) 可以先差分然后再使用 \(set\) 或线段树合并来维护,因为线段树可以支持区间加如果要维护的东西涉及区间修改那么可以使用线段树合并,而类似 AT2347 [ARC070C] NarrowRectangles 可以用 \(set\) 来维护拐点支持函数平移。