Codeforces Round #805 (Div. 3) G. Passable Paths(Lca)

G. Passable Paths

题意 : 给定一棵树,随后有q个询问,每次询问会给出k个点,问这些点是否在一条链上

分析: 对于一条链,设a,b为链的两个端点,则任意一个点c在这条链上的条件是
\(dist(a,c) + dist(c,b) = dist(a,b)\)
倘若有一个点不满足上式则说明给定的点不在同一条链上。由于端点a,b事先并不知道,我们先取前两个点作为当前链的端点,再依次枚举其余的点c:若某个端点落在c与另一个端点之间,说明c在链的延长线上,就用c替换该端点;若c满足上式,则端点保持不变;否则答案为NO
dist 可以通过 lca 来求
ac代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <queue>
#include <map>
#include <vector>
#include <stack>
#include <set>
#include <sstream>
#include <fstream> 
#include <cmath>
#include <iomanip>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
// Shorthand for the members of a pair.
// NOTE: these are textual macros, so every later `x` / `y` token in the file
// (including local variable names) is rewritten to `first` / `second`.
#define x first
#define y second
// Statement-style macro, used as `ios;`: turns off C stdio synchronisation and
// unties cin/cout. The token `ios` is rewritten everywhere below this line.
#define ios ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
// Makes `endl` a plain newline character instead of the flushing manipulator.
#define endl '\n'
// Container helpers: push_back, the full iterator range, and the range that
// skips the first element (for 1-indexed containers).
#define pb push_back
#define all(x) x.begin(),x.end()
#define all1(x) x.begin()+1,x.end()
// Value of pi as a floating-point literal.
#define pi 3.14159265358979323846	
using namespace std;
// Integer and pair aliases.
typedef long long LL;
typedef pair<LL,LL> PII;
 
// N is the capacity of every global array below; it bounds both the number of
// vertices and the number of directed edge slots. M, INF and mod are not
// referenced by the code shown here.
const int N = 400010,M = 400010,INF = 0x3f3f3f3f,mod = 998244353;
// Large floating-point constant; not referenced by the code shown here.
const double INFF = 0x7f7f7f7f7f7f7f7f;
 
// n: number of vertices, m: number of queries still to process,
// k: size of the current query (all read in main).
// t is not referenced at global scope in the code shown (bfs declares its own).
int n,m,k,t;
// Array-based adjacency list: h[v] is the index of the first edge leaving v
// (-1 when there is none), e[i] is the target of edge i, ne[i] is the next
// edge in the same list, and idx is the next free edge slot.
int h[N],e[N],ne[N],idx;
// depth[v] is the BFS depth of v with the root (vertex 1) at depth 1;
// fa[v][j] is the 2^j-th ancestor of v, with 0 standing for "above the root".
int depth[N],fa[N][19];
// Appends the directed edge a -> b to the array-based adjacency list.
// The edge takes the next free slot, links to the previous head of a's list
// through `ne`, and becomes the new head, so lists are walked newest-first.
void add(int a,int b)
{
	const int slot = idx;
	++idx;

	e[slot] = b;
	ne[slot] = h[a];
	h[a] = slot;
}

// Breadth-first traversal from vertex 1 that fills `depth` and the
// binary-lifting table `fa` (fa[v][j] is the 2^j-th ancestor of v).
// Vertex 0 serves as a sentinel above the root: it has depth 0 while the root
// has depth 1, and since `fa` is a zero-initialised global, any jump that
// would leave the tree lands on 0.
void bfs()
{
	// Filling with 0x3f bytes gives every entry a large value meaning "unvisited".
	std::memset(depth,0x3f,sizeof depth);
	depth[0] = 0;
	depth[1] = 1;

	std::queue<int> pending;
	pending.push(1);

	while(!pending.empty())
	{
		const int cur = pending.front();
		pending.pop();

		for(int edge = h[cur];edge != -1;edge = ne[edge])
		{
			const int child = e[edge];
			// Skip vertices already reached at this depth or shallower
			// (in particular the parent we just came from).
			if(depth[child] <= depth[cur] + 1) continue;

			depth[child] = depth[cur] + 1;
			fa[child][0] = cur;
			// The parent's row is complete because it was dequeued earlier,
			// so each level can be derived from the level below it.
			for(int lvl = 1;lvl <= 18;++lvl)
				fa[child][lvl] = fa[fa[child][lvl - 1]][lvl - 1];

			pending.push(child);
		}
	}
}

// Returns the lowest common ancestor of a and b using the binary-lifting
// table built by bfs().
int lca(int a,int b)
{
	// Keep `a` as the deeper of the two vertices.
	if(depth[a] < depth[b]) std::swap(a,b);

	// Lift `a` to the depth of `b`, largest jumps first. A jump past the root
	// lands on sentinel 0 (depth 0) and is therefore never taken.
	for(int lvl = 18;lvl >= 0;--lvl)
	{
		const int up = fa[a][lvl];
		if(depth[up] >= depth[b]) a = up;
	}

	if(a == b) return a;

	// Lift both vertices together while their ancestors still differ; they
	// end up as two distinct children of the answer.
	for(int lvl = 18;lvl >= 0;--lvl)
	{
		const int up_a = fa[a][lvl];
		const int up_b = fa[b][lvl];
		if(up_a == up_b) continue;

		a = up_a;
		b = up_b;
	}
	return fa[a][0];
}

// Returns the number of edges on the tree path between a and b: the drop from
// each vertex down from their lowest common ancestor, added together.
int get_dist(int a,int b)
{
	const int top = lca(a,b);
	const int drop_a = depth[a] - depth[top];
	const int drop_b = depth[b] - depth[top];
	return drop_a + drop_b;
}

int main()
{
	ios;
	memset(h,-1,sizeof h);
	cin >> n;
	for(int i = 1;i < n;i ++)
	{
		int a,b;
		cin >> a >> b;
		add(b,a),add(a,b);
	}

	bfs();

	cin >> m;

	while(m --)
	{
		cin >> k;
		vector<int> a(k + 1);
		for(int i = 1;i <= k;i ++) cin >> a[i];
		if(k == 1 || k == 2) 
		{
			cout << "YES" << endl;
			continue;
		}
		bool success = true;

		int x =  a[1],y = a[2];

		for(int i = 3;i <= k;i ++)
		{
			int z = a[i];
			
			if(get_dist(x,y) + get_dist(y,z) == get_dist(x,z))
			{
				y = z;
			}
			else if(get_dist(z,x) + get_dist(x,y) == get_dist(z,y))
			{
				x = z;
			}
			else if(get_dist(x,z) + get_dist(z,y) != get_dist(x,y))
			{
				success = false;
				break;
			} 
		}
		if(success) cout << "YES" << endl;
		else cout << "NO" << endl;

	}
	return 0;
}
posted @ 2022-07-14 11:18  notyour_young  阅读(36)  评论(0编辑  收藏  举报