Loading

"蔚来杯"2022牛客暑期多校训练营3 A. Ancestor(LCA)

题目大意是给定两棵节点数相同的树,每个点有一个权值。现在给出k个关键点编号,问有多少个关键点编号,将其删除后在两棵树上分别对剩下的关键点编号对应的点求LCA得到的两个祖先中第一棵树的祖先权值大于第二棵树的祖先权值。

首先可以扫一遍所有的关键点,如果这个点满足条件就更新答案。那么问题就转化成了删除一个点,怎么求剩下两边的点的LCA。注意到LCA计算满足交换律和结合律,因此可以维护每棵树对应关键点的前缀LCA(设为LCA_pre)和后缀LCA(设为LCA_suf),如果删去点i,那么剩下的点的LCA就是lca(LCA_pre[i - 1], LCA_suf[i + 1])。单独判断一下边界即可。

#include <bits/stdc++.h>
#define N 200005
using namespace std;
int n, k, x[N], a[N], b[N];
int heada[N], vera[2 * N], Nexta[2 * N], tota = 0;
int headb[N], verb[2 * N], Nextb[2 * N], totb = 0;
int lca_a1[N], lca_a2[N], lca_b1[N], lca_b2[N];
int fa[N][20], da[N];
queue<int> qa;
int fb[N][20], db[N];
queue<int> qb;
int t;
void adda(int x, int y) {
	vera[++tota] = y, Nexta[tota] = heada[x], heada[x] = tota;
}
void addb(int x, int y) {
	verb[++totb] = y, Nextb[totb] = headb[x], headb[x] = totb;
}
void bfsa() {
	qa.push(1); da[1] = 1;
	while(qa.size()) {
		int x = qa.front(); qa.pop();
		for(int i = heada[x]; i; i = Nexta[i]) {
			int y = vera[i];
			if(da[y]) continue;
			da[y] = da[x] + 1;
			fa[y][0] = x;
			for(int j = 1; j <= t; j++) {
				fa[y][j] = fa[fa[y][j - 1]][j - 1];
			}
			qa.push(y);
		}
	}
}	
void bfsb() {
	qb.push(1); db[1] = 1;
	while(qb.size()) {
		int x = qb.front(); qb.pop();
		for(int i = headb[x]; i; i = Nextb[i]) {
			int y = verb[i];
			if(db[y]) continue;
			db[y] = db[x] + 1;
			fb[y][0] = x;
			for(int j = 1; j <= t; j++) {
				fb[y][j] = fb[fb[y][j - 1]][j - 1];
			}
			qb.push(y);
		}
	}
}
int lca_a(int x, int y) {
	if(da[x] > da[y]) swap(x, y);
	for(int i = t; i >= 0; i--) {
		if(da[fa[y][i]] >= da[x]) y = fa[y][i];
	}
	if(x == y) return x;
	for(int i = t; i >= 0; i--) {
		if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
	}
	return fa[x][0];
}
int lca_b(int x, int y) {
	if(db[x] > db[y]) swap(x, y);
	for(int i = t; i >= 0; i--) {
		if(db[fb[y][i]] >= db[x]) y = fb[y][i];
	}
	if(x == y) return x;
	for(int i = t; i >= 0; i--) {
		if(fb[x][i] != fb[y][i]) x = fb[x][i], y = fb[y][i];
	}
	return fb[x][0];
}
void solve() {
	cin >> n >> k;
	t = (int)(log(n) / log(2)) + 1;
	for(int i = 1; i <= k; i++) {
		cin >> x[i];
	}
	for(int i = 1; i <= n; i++) {
		cin >> a[i];
	}
	for(int i = 2; i <= n; i++) {
		int pa;
		cin >> pa;
		adda(pa, i);
		adda(i, pa);
	}
	for(int i = 1; i <= n; i++) {
		cin >> b[i];
	}
	for(int i = 2; i <= n; i++) {
		int pb;
		cin >> pb;
		addb(pb, i);
		addb(i, pb);
	}
	bfsa();
	bfsb();
	lca_a1[1] = x[1];
	for(int i = 2; i <= k; i++) {
		lca_a1[i] = lca_a(lca_a1[i - 1], x[i]);
	}
	lca_a2[k] = x[k];
	for(int i = k - 1; i >= 1; i--) {
		lca_a2[i] = lca_a(lca_a2[i + 1], x[i]);
	}
	lca_b1[1] = x[1];
	for(int i = 2; i <= k; i++) {
		lca_b1[i] = lca_b(lca_b1[i - 1], x[i]);
	}
	lca_b2[k] = x[k];
	for(int i = k - 1; i >= 1; i--) {
		lca_b2[i] = lca_b(lca_b2[i + 1], x[i]);
	}
	int ans = 0;
	for(int i = 1; i <= k; i++) {
		if(i == 1) {
			if(a[lca_a2[2]] > b[lca_b2[2]]) ans++;
		} else if(i == k) {
			if(a[lca_a1[k - 1]] > b[lca_b1[k - 1]]) ans++;
		} else {
			if(a[lca_a(lca_a1[i - 1], lca_a2[i + 1])] > b[lca_b(lca_b1[i - 1], lca_b2[i + 1])]) ans++;
		}
	}
	cout << ans << endl;
}

int main() {
	solve();
	return 0;
}
posted @ 2022-11-16 10:56  脂环  阅读(34)  评论(0编辑  收藏  举报