【bzoj3626】[LNOI2014]LCA 树链剖分+线段树
题目描述
给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)
输入
第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。
输出
输出q行,每行表示一个询问的答案。每个答案对201314取模输出
样例输入
5 2
0
0
1
1
1 4 3
1 4 2
样例输出
8
5
题解
树链剖分+线段树
考虑两点LCA的深度,可以看作两个点到根节点的路径交的长度(点的个数)。
而路径交的长度,又可以看作把一条路径上的点权值+1,然后查询另一条路径上的点的权值和。
于是本题转化为:把编号在$[l,r]$内的所有点到根路径上的点权值+1,再查询z到根的点权和。
于是我们可以把问题转化为前缀相减的形式,即求编号在$[1,p]$内的所有点到根路径上的点权值+1,查询z到根的点权和。
将拆成前缀相减后的询问离线,按照$p$排序。按照顺序直接处理对应编号,再查询即可。此时需要支持链上修改、链上查询,使用树链剖分+线段树即可。
时间复杂度$O(n\log^2n)$
#include <cstdio> #include <algorithm> #define N 50010 #define lson l , mid , x << 1 #define rson mid + 1 , r , x << 1 | 1 using namespace std; struct data { int p , z , v , id; data() {} data(int P , int Z , int V , int Id) {p = P , z = Z , v = V , id = Id;} bool operator<(const data &a)const {return p < a.p;} }a[N << 1]; int n , head[N] , to[N] , next[N] , cnt , fa[N] , si[N] , bl[N] , pos[N] , tot , sum[N << 2] , tag[N << 2] , ans[N]; inline void add(int x , int y) { to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt; } void dfs1(int x) { int i; si[x] = 1; for(i = head[x] ; i ; i = next[i]) dfs1(to[i]) , si[x] += si[to[i]]; } void dfs2(int x , int c) { int i , k = n; bl[x] = c , pos[x] = ++tot; for(i = head[x] ; i ; i = next[i]) if(si[to[i]] > si[k]) k = to[i]; if(k != n) { dfs2(k , c); for(i = head[x] ; i ; i = next[i]) if(to[i] != k) dfs2(to[i] , to[i]); } } inline void pushdown(int l , int r , int x) { if(tag[x]) { int mid = (l + r) >> 1; sum[x << 1] += tag[x] * (mid - l + 1) , tag[x << 1] += tag[x]; sum[x << 1 | 1] += tag[x] * (r - mid) , tag[x << 1 | 1] += tag[x]; tag[x] = 0; } } void update(int b , int e , int l , int r , int x) { if(b <= l && r <= e) { sum[x] += r - l + 1 , tag[x] ++ ; return; } pushdown(l , r , x); int mid = (l + r) >> 1; if(b <= mid) update(b , e , lson); if(e > mid) update(b , e , rson); sum[x] = sum[x << 1] + sum[x << 1 | 1]; } int query(int b , int e , int l , int r , int x) { if(b <= l && r <= e) return sum[x]; pushdown(l , r , x); int mid = (l + r) >> 1 , ans = 0; if(b <= mid) ans += query(b , e , lson); if(e > mid) ans += query(b , e , rson); return ans; } void modify(int x) { while(bl[x]) update(pos[bl[x]] , pos[x] , 1 , n , 1) , x = fa[bl[x]]; update(1 , pos[x] , 1 , n , 1); } int solve(int x) { int ans = 0; while(bl[x]) ans += query(pos[bl[x]] , pos[x] , 1 , n , 1) , x = fa[bl[x]]; return ans + query(1 , pos[x] , 1 , n , 1); } int main() { int m , i , l , r , x , h = 0; scanf("%d%d" , &n , &m); for(i = 1 ; i < n ; i ++ ) scanf("%d" , &fa[i]) , add(fa[i] , i); dfs1(0) , dfs2(0 , 0); for(i = 1 ; i <= m ; i ++ ) scanf("%d%d%d" , &l , &r , &x) , a[i] = data(l - 1 , x , -1 , i) , a[i + m] = data(r , x , 1 , i); sort(a + 1 , a + 2 * m + 1); for(i = 1 ; i <= 2 * m ; i ++ ) { while(h <= a[i].p) modify(h++); ans[a[i].id] += a[i].v * solve(a[i].z); } for(i = 1 ; i <= m ; i ++ ) printf("%d\n" , ans[i] % 201314); return 0; }