BZOJ5479: tree
Description
给出一棵树,根节点为1
给出两个集合,集合由树上节点组成
从两个集合分别选出一个元素,求其LCA
问LCA的最大深度是多少
Input
第一行给出数据组数T
对于每组数据
第一行给出N,M,代表树的节点个数及询问次数
接下来N-1行,每行两个正整数u,v,表示u,v之间有边
接下来2M行,每两个表示一个询问
询问的第一行,第一行正整数a代表集合A中元素个数
接下来a个正整数,表示集合中的节点
询问的第二行,第一行正整数b代表集合B中元素个数
接下来b个正整数,表示集合中的节点
T<=5,N,M<=100000
sigma(a)+sigma(b)<=200000
a,b<=N
Output
对于每个询问,一行一个整数表示结果
Sample Input
1
7 3
1 2
1 3
3 4
3 5
4 6
4 7
1 6
1 7
2 6 7
1 7
2 5 4
2 3 2
Sample Output
3
4
2
Solution
考虑怎么样才会让lca的深度最大,两个点的必须尽可能相近。
那么有一个直观的贪心想法,处理出整个树的dfs序(最好树剖,因为这样会优先编号重儿子),将b数组按dfs序排序,每次对于a中的元素,在b中二分与它dfs序最接近的两个点,对lca的深度取max。
多组数据,记得清空数组。
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline
namespace io {
#define in(a) a=read()
#define out(a) write(a)
#define outn(a) out(a),putchar('\n')
#define I_int int
inline I_int read() {
I_int x = 0 , f = 1 ; char c = getchar() ;
while( c < '0' || c > '9' ) { if( c == '-' ) f = -1 ; c = getchar() ; }
while( c >= '0' && c <= '9' ) { x = x * 10 + c - '0' ; c = getchar() ; }
return x * f ;
}
char F[ 200 ] ;
inline void write( I_int x ) {
if( x == 0 ) { putchar( '0' ) ; return ; }
I_int tmp = x > 0 ? x : -x ;
if( x < 0 ) putchar( '-' ) ;
int cnt = 0 ;
while( tmp > 0 ) {
F[ cnt ++ ] = tmp % 10 + '0' ;
tmp /= 10 ;
}
while( cnt > 0 ) putchar( F[ -- cnt ] ) ;
}
#undef I_int
}
using namespace io ;
using namespace std ;
#define N 200010
int T;
int n, m, head[N], cnt;
struct edge {
int to, nxt;
} e[N<<1];
int fa[N], dep[N], siz[N], top[N], id[N];
int x, y;
struct Node {
int dfn, val;
}a[N], b[N];
void dfs1(int u) {
siz[u] = 1;
for(int i = head[u]; i; i = e[i].nxt) {
if(e[i].to == fa[u]) continue;
fa[e[i].to] = u;
dep[e[i].to] = dep[u] + 1;
dfs1(e[i].to);
siz[u] += siz[e[i].to];
}
}
int tim = 0;
void dfs2(int u, int topf) {
id[u] = ++tim;
top[u] = topf;
int k = 0;
for(int i = head[u]; i; i = e[i].nxt) {
if(e[i].to == fa[u]) continue;
if(siz[e[i].to] > siz[k]) k = e[i].to;
}
if(!k) return;
dfs2(k, topf);
for(int i = head[u]; i; i = e[i].nxt) {
if(e[i].to == fa[u] || e[i].to == k) continue;
dfs2(e[i].to, e[i].to);
}
}
int lca(int x, int y) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
return x;
}
void init() {
cnt = 0; tim = 0;
memset(head, 0, sizeof(head));
memset(dep, 0, sizeof(dep));
memset(id, 0, sizeof(id));
memset(top, 0, sizeof(top));
memset(siz, 0, sizeof(siz));
memset(fa, 0, sizeof(fa));
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
}
void ins(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
int find(int t) {
int l = 1, r = y, ans = y;
while(l <= r) {
int mid = (l + r) >> 1;
if(b[mid].dfn >= t) ans = mid, r = mid - 1;
else l = mid + 1;
}
return ans;
}
bool operator < (Node a, Node b) {
return a.dfn < b.dfn;
}
int main() {
T = read();
while(T--) {
n = read(), m = read();
init();
for(int i = 1; i < n; ++i) {
int u = read(), v = read();
ins(u, v); ins(v, u);
}
dep[1] = 1;
dfs1(1); dfs2(1, 1);
while(m--) {
int ans = 0;
x = read();
for(int i = 1; i <= x; ++i) a[i].val = read(), a[i].dfn = id[a[i].val];
y = read();
for(int i = 1; i <= y; ++i) b[i].val = read(), b[i].dfn = id[b[i].val];
sort(b + 1, b + y + 1);
for(int i = 1; i <= x; ++i) {
int t = find(a[i].dfn);
ans = max(ans, dep[lca(b[t].val, a[i].val)]);
ans = max(ans, dep[lca(b[t - 1].val, a[i].val)]);
}
outn(ans);
}
}
}