"蔚来杯"2022牛客暑期多校训练营3 A. Ancestor(LCA)
题目大意是给定两棵节点数相同的树,每个点有一个权值。现在给出k个关键点编号,问有多少个关键点编号,将其删除后在两棵树上分别对剩下的关键点编号对应的点求LCA得到的两个祖先中第一棵树的祖先权值大于第二棵树的祖先权值。
首先可以扫一遍所有的关键点,如果这个点满足条件就更新答案。那么问题就转化成了删除一个点,怎么求剩下两边的点的LCA。注意到LCA计算满足交换律和结合律,因此可以维护每棵树对应关键点的前缀LCA(设为LCA_pre)和后缀LCA(设为LCA_suf),如果删去点i,那么剩下的点的LCA就是lca(LCA_pre[i - 1], LCA_suf[i + 1])。单独判断一下边界即可。
#include <bits/stdc++.h>
#define N 200005
using namespace std;
int n, k, x[N], a[N], b[N];
int heada[N], vera[2 * N], Nexta[2 * N], tota = 0;
int headb[N], verb[2 * N], Nextb[2 * N], totb = 0;
int lca_a1[N], lca_a2[N], lca_b1[N], lca_b2[N];
int fa[N][20], da[N];
queue<int> qa;
int fb[N][20], db[N];
queue<int> qb;
int t;
void adda(int x, int y) {
vera[++tota] = y, Nexta[tota] = heada[x], heada[x] = tota;
}
void addb(int x, int y) {
verb[++totb] = y, Nextb[totb] = headb[x], headb[x] = totb;
}
void bfsa() {
qa.push(1); da[1] = 1;
while(qa.size()) {
int x = qa.front(); qa.pop();
for(int i = heada[x]; i; i = Nexta[i]) {
int y = vera[i];
if(da[y]) continue;
da[y] = da[x] + 1;
fa[y][0] = x;
for(int j = 1; j <= t; j++) {
fa[y][j] = fa[fa[y][j - 1]][j - 1];
}
qa.push(y);
}
}
}
void bfsb() {
qb.push(1); db[1] = 1;
while(qb.size()) {
int x = qb.front(); qb.pop();
for(int i = headb[x]; i; i = Nextb[i]) {
int y = verb[i];
if(db[y]) continue;
db[y] = db[x] + 1;
fb[y][0] = x;
for(int j = 1; j <= t; j++) {
fb[y][j] = fb[fb[y][j - 1]][j - 1];
}
qb.push(y);
}
}
}
int lca_a(int x, int y) {
if(da[x] > da[y]) swap(x, y);
for(int i = t; i >= 0; i--) {
if(da[fa[y][i]] >= da[x]) y = fa[y][i];
}
if(x == y) return x;
for(int i = t; i >= 0; i--) {
if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
}
return fa[x][0];
}
int lca_b(int x, int y) {
if(db[x] > db[y]) swap(x, y);
for(int i = t; i >= 0; i--) {
if(db[fb[y][i]] >= db[x]) y = fb[y][i];
}
if(x == y) return x;
for(int i = t; i >= 0; i--) {
if(fb[x][i] != fb[y][i]) x = fb[x][i], y = fb[y][i];
}
return fb[x][0];
}
void solve() {
cin >> n >> k;
t = (int)(log(n) / log(2)) + 1;
for(int i = 1; i <= k; i++) {
cin >> x[i];
}
for(int i = 1; i <= n; i++) {
cin >> a[i];
}
for(int i = 2; i <= n; i++) {
int pa;
cin >> pa;
adda(pa, i);
adda(i, pa);
}
for(int i = 1; i <= n; i++) {
cin >> b[i];
}
for(int i = 2; i <= n; i++) {
int pb;
cin >> pb;
addb(pb, i);
addb(i, pb);
}
bfsa();
bfsb();
lca_a1[1] = x[1];
for(int i = 2; i <= k; i++) {
lca_a1[i] = lca_a(lca_a1[i - 1], x[i]);
}
lca_a2[k] = x[k];
for(int i = k - 1; i >= 1; i--) {
lca_a2[i] = lca_a(lca_a2[i + 1], x[i]);
}
lca_b1[1] = x[1];
for(int i = 2; i <= k; i++) {
lca_b1[i] = lca_b(lca_b1[i - 1], x[i]);
}
lca_b2[k] = x[k];
for(int i = k - 1; i >= 1; i--) {
lca_b2[i] = lca_b(lca_b2[i + 1], x[i]);
}
int ans = 0;
for(int i = 1; i <= k; i++) {
if(i == 1) {
if(a[lca_a2[2]] > b[lca_b2[2]]) ans++;
} else if(i == k) {
if(a[lca_a1[k - 1]] > b[lca_b1[k - 1]]) ans++;
} else {
if(a[lca_a(lca_a1[i - 1], lca_a2[i + 1])] > b[lca_b(lca_b1[i - 1], lca_b2[i + 1])]) ans++;
}
}
cout << ans << endl;
}
int main() {
solve();
return 0;
}