Codeforces Round #805 (Div. 3) G. Passable Paths(Lca)
G. Passable Paths
题意:给定一棵树,随后有 q 个询问,每次询问给出 k 个点,问这些点是否都在同一条链(简单路径)上。
分析:设 a、b 为链的两个端点,则任意一个点 c 在该链上的充要条件是
\(dist(a,c) + dist(c,b) = dist(a,b)\)
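倘若有一个点不满足上式,则说明给定的点集不在同一条链上。由于端点 a、b 事先未知,我们依次加入 k 个点并动态维护当前端点 x、y:对新点 z,若 y 在 x 到 z 的路径上,则令 y = z(向 y 一侧延伸);若 x 在 z 到 y 的路径上,则令 x = z(向 x 一侧延伸);否则 z 必须位于 x 到 y 的路径上,不满足则答案为 NO。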
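dist 可以通过 LCA 来求:\(\mathrm{dist}(u,v) = \mathrm{depth}(u) + \mathrm{depth}(v) - 2\,\mathrm{depth}(\mathrm{lca}(u,v))\)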
ac代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <queue>
#include <map>
#include <vector>
#include <stack>
#include <set>
#include <sstream>
#include <fstream>
#include <cmath>
#include <iomanip>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
// Competitive-programming shorthand macros. They are kept exactly as before
// because later code in this file relies on `ios` and `endl`.
// NOTE(review): `x` / `y` silently rename every identifier spelled x or y
// to first / second for the rest of the translation unit.
#define x first
#define y second
#define ios ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define endl '\n'
#define pb push_back
#define all(x) x.begin(),x.end()
#define all1(x) x.begin()+1,x.end()
#define pi 3.14159265358979323846
using namespace std;

using LL = long long;
using PII = pair<LL, LL>;

// Capacity of every global array below. It covers both vertex-indexed arrays
// and the edge arrays (each undirected edge is added twice in main).
constexpr int N = 400010;
constexpr int M = 400010;
constexpr int INF = 0x3f3f3f3f;
constexpr int mod = 998244353;
constexpr double INFF = 0x7f7f7f7f7f7f7f7f;

// n: vertex count, m: query count, k: points in the current query,
// t: not used by the visible code (bfs declares its own local t).
int n, m, k, t;
// Linked-list adjacency storage: h[v] = first edge id of v (-1 = none),
// e[id] = edge target, ne[id] = next edge id, idx = next free edge id.
int h[N], e[N], ne[N], idx;
// depth[v]: BFS depth (root = 1); fa[v][i]: the 2^i-th ancestor of v.
int depth[N], fa[N][19];
// Inserts the directed edge a -> b at the head of vertex a's adjacency list.
// The new edge takes slot `idx`, and `idx` then advances to the next free slot.
void add(int a, int b)
{
    e[idx] = b;       // target of the new edge
    ne[idx] = h[a];   // link to the previous head of a's list
    h[a] = idx;       // the new edge becomes the head
    ++idx;
}
void bfs()
{
memset(depth,0x3f,sizeof depth);
depth[0] = 0;
depth[1] = 1;
queue<int> q;
q.push(1);
while(q.size())
{
int t = q.front();
q.pop();
for(int i = h[t];~ i;i = ne[i])
{
int j = e[i];
if(depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
fa[j][0] = t;
for(int k = 1;k <= 18;k ++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
q.push(j);
}
}
}
}
// Lowest common ancestor of a and b, using the fa[][] / depth[] tables built by bfs().
// Jumps landing on the sentinel vertex 0 (depth 0) are rejected by the depth
// comparison, because every real vertex has depth >= 1.
int lca(int a, int b)
{
    if (depth[a] < depth[b])
        swap(a, b);                       // ensure `a` is the deeper vertex

    // Phase 1: lift `a` to the same depth as `b`.
    for (int lvl = 18; lvl >= 0; --lvl)
    {
        const int up = fa[a][lvl];
        if (depth[up] >= depth[b])
            a = up;
    }
    if (a == b)
        return a;                         // b was an ancestor of the original a

    // Phase 2: lift both vertices while their ancestors differ.
    // They end up as the two children of the LCA.
    for (int lvl = 18; lvl >= 0; --lvl)
    {
        if (fa[a][lvl] == fa[b][lvl])
            continue;
        a = fa[a][lvl];
        b = fa[b][lvl];
    }
    return fa[a][0];
}
// Number of edges on the tree path between a and b: the climb from each
// vertex up to their lowest common ancestor.
int get_dist(int a, int b)
{
    const int meet = lca(a, b);
    return (depth[a] - depth[meet]) + (depth[b] - depth[meet]);
}
int main()
{
ios;
memset(h,-1,sizeof h);
cin >> n;
for(int i = 1;i < n;i ++)
{
int a,b;
cin >> a >> b;
add(b,a),add(a,b);
}
bfs();
cin >> m;
while(m --)
{
cin >> k;
vector<int> a(k + 1);
for(int i = 1;i <= k;i ++) cin >> a[i];
if(k == 1 || k == 2)
{
cout << "YES" << endl;
continue;
}
bool success = true;
int x = a[1],y = a[2];
for(int i = 3;i <= k;i ++)
{
int z = a[i];
if(get_dist(x,y) + get_dist(y,z) == get_dist(x,z))
{
y = z;
}
else if(get_dist(z,x) + get_dist(x,y) == get_dist(z,y))
{
x = z;
}
else if(get_dist(x,z) + get_dist(z,y) != get_dist(x,y))
{
success = false;
break;
}
}
if(success) cout << "YES" << endl;
else cout << "NO" << endl;
}
return 0;
}