【树链剖分/倍增模板】【洛谷】3398:仓鼠找sugar
题目描述
小仓鼠的和他的基(mei)友(zi)sugar住在地下洞穴中,每个节点的编号为1~n。地下洞穴是一个树形结构。这一天小仓鼠打算从从他的卧室(a)到餐厅(b),而他的基友同时要从他的卧室(c)到图书馆(d)。他们都会走最短路径。现在小仓鼠希望知道,有没有可能在某个地方,可以碰到他的基友?
小仓鼠那么弱,还要天天被zzq大爷虐,请你快来救救他吧!
输入输出格式
输入格式:
第一行两个正整数n和q,表示这棵树节点的个数和询问的个数。
接下来n-1行,每行两个正整数u和v,表示节点u到节点v之间有一条边。
接下来q行,每行四个正整数a、b、c和d,表示节点编号,也就是一次询问,其意义如上。
输出格式:
对于每个询问,如果有公共点,输出大写字母“Y”;否则输出“N”。
输入输出样例
说明
__本题时限1s,内存限制128M,因新评测机速度较为接近NOIP评测机速度,请注意常数问题带来的影响。__
20%的数据 n<=200,q<=200
40%的数据 n<=2000,q<=2000
70%的数据 n<=50000,q<=50000
100%的数据 n<=100000,q<=100000
为什么要专门把这道题拿出来写?因为这道题有多种方法,链剖和倍增可以刚好一个对应一个方法!
就拿来写了。
这道题实际上就是判断树上两条路径是否有交点。第一次想到的是将其中一条链染色,在另外一条链上查询有没有被染色,实际上就是链剖+线段树区间修改/查询了。这个既好想又好写,一次a。
#include<iostream> #include<cstdio> using namespace std; int n, q; struct Node { int v, nex; Node ( int v = 0, int nex = 0 ) : v ( v ), nex ( nex ) { } } Edge[200005]; int h[100005], stot; void add ( int u, int v ) { Edge[++stot] = Node ( v, h[u] ); h[u] = stot; } int fa[100005], siz[100005], son[100005], dep[100005]; void dfs1 ( int u, int f ) { fa[u] = f; siz[u] = 1; dep[u] = dep[f] + 1; for ( int i = h[u]; i; i = Edge[i].nex ) { int v = Edge[i].v; if ( v == f ) continue; dfs1 ( v, u ); siz[u] += siz[v]; if ( siz[v] > siz[son[u]] ) son[u] = v; } } int top[100005], in[100005], ti; void dfs2 ( int u, int t ) { top[u] = t; in[u] = ++ ti; if ( son[u] ) dfs2 ( son[u], t ); for ( int i = h[u]; i; i = Edge[i].nex ) { int v = Edge[i].v; if ( v == fa[u] || v == son[u] ) continue; dfs2 ( v, v ); } } int TR[400004], tag[400005]; void update ( int nd ) { TR[nd] = TR[nd << 1] + TR[nd << 1 | 1]; } void push_down ( int nd, int l, int r ) { if ( tag[nd] ) { int mid = ( l + r ) >> 1; tag[nd << 1] += tag[nd]; tag[nd << 1 | 1] += tag[nd]; TR[nd << 1] += tag[nd] * ( mid - l + 1 ); TR[nd << 1 | 1] += tag[nd] * ( r - mid ); tag[nd] = 0; } } void add ( int nd, int l, int r, int L, int R, int d ) { if ( l >= L && r <= R ) { TR[nd] += d * ( r - l + 1 ); tag[nd] += d; return ; } push_down ( nd, l, r ); int mid = ( l + r ) >> 1; if ( L <= mid ) add ( nd << 1, l, mid, L, R, d ); if ( R > mid ) add ( nd << 1 | 1, mid + 1, r, L, R, d ); update ( nd ); } void add ( int u, int v, int d ) { while ( top[u] != top[v] ) { if ( dep[top[u]] < dep[top[v]] ) swap ( u, v ); add ( 1, 1, n, in[top[u]], in[u], d ); u = fa[top[u]]; } if ( dep[u] < dep[v] ) swap ( u, v ); add ( 1, 1, n, in[v], in[u], d ); } int query ( int nd, int l, int r, int L, int R ) { if ( l >= L && r <= R ) return TR[nd]; push_down ( nd, l, r ); int ans = 0; int mid = ( l + r ) >> 1; if ( L <= mid ) ans += query ( nd << 1, l, mid, L, R ); if ( R > mid ) ans += query ( nd << 1 | 1, mid + 1, r, L, R ); return ans; } int query ( int u, int v ) { int ans = 0; while ( top[u] != top[v] ) { if ( dep[top[u]] < dep[top[v]] ) swap ( u, v ); ans += query ( 1, 1, n, in[top[u]], in[u] ); u = fa[top[u]]; } if ( dep[u] < dep[v] ) swap ( u, v ); ans += query ( 1, 1, n, in[v], in[u] ); return ans; } int main ( ) { scanf ( "%d%d", &n, &q ); for ( int i = 1; i < n; i ++ ) { int u, v; scanf ( "%d%d", &u, &v ); add ( u, v ); add ( v, u ); } dfs1 ( 1, 0 ); dfs2 ( 1, 0 ); for ( int i = 1; i <= q; i ++ ) { int a, b, c, d; scanf ( "%d%d%d%d", &a, &b, &c, &d ); add ( a, b, 1 ); int tmp = query ( c, d ); if ( tmp ) printf ( "Y\n" ); else printf ( "N\n" ); add ( a, b, -1 ); } return 0; }
然而这道题不需要数据结构维护也可以做到,只需要求$LCA$就可以。
我们可以发现两条链相交的性质:一条链的两端点的$LCA$必定在另一条链上(然而并不会证明),所以需要判断的就是一个点是否在一条链上。
而判断成立的条件有:
1、$dep[x]>=dep[LCA(s,t)]$
2、$LCA(x,s)==x||LCA(x,t)==x$
所以每次取出深度更深的$LCA$和另一条路径的两个端点来判断即可。
#include<iostream> #include<cstdio> using namespace std; const int P = 20; int n, q; struct Node { int v, nex; Node ( int v = 0, int nex = 0 ) : v ( v ), nex ( nex ) { } } Edge[200005]; int h[100005], stot; void add ( int u, int v ) { Edge[++stot] = Node ( v, h[u] ); h[u] = stot; } int jum[100005][P+1], dep[100005]; void dfs ( int u, int f ) { jum[u][0] = f; for ( int i = 1; i <= P; i ++ ) jum[u][i] = jum[jum[u][i-1]][i-1]; for ( int i = h[u]; i; i = Edge[i].nex ) { int v = Edge[i].v; if ( v == f ) continue; dep[v] = dep[u] + 1; dfs ( v, u ); } } int LCA ( int u, int v ) { if ( dep[u] < dep[v] ) swap ( u, v ); int t = dep[u] - dep[v]; for ( int p = 0; t; t >>= 1, p ++ ) if ( t & 1 ) u = jum[u][p]; if ( u == v ) return u; for ( int p = P; p >= 0; p -- ) if ( jum[u][p] != jum[v][p] ) u = jum[u][p], v = jum[v][p]; return jum[u][0]; } int main ( ) { scanf ( "%d%d", &n, &q ); for ( int i = 1; i < n; i ++ ) { int u, v; scanf ( "%d%d", &u, &v ); add ( u, v ); add ( v, u ); } dfs ( 1, 0 ); for ( int i = 1; i <= q; i ++ ) { int a, b, c, d; scanf ( "%d%d%d%d", &a, &b, &c, &d ); int S = LCA ( a, b ), T = LCA ( c, d ); if ( dep[S] < dep[T] ) { swap ( S, T ); swap ( a, c ); swap ( b, d ); } if ( LCA ( S, c ) == S || LCA ( S, d ) == S ) cout << "Y" << endl; else cout << "N" << endl; } return 0; }