LCA「树链剖分+线段树」
LCA「树链剖分+线段树」
题目描述
给出一个 \(n\) 个节点的有根树(编号为 \(0\) 到 \(n-1\),根节点为 \(0\))。一个点的深度定义为这个节点到根的距离 \(+1\) 。
设 \(dep[i]\) 表示点 \(i\) 的深度,\(LCA(i,j)\) 表示 \(i\) 与 \(j\) 的最近公共祖先。
有 \(q\) 次询问,每次询问给出 \(l,r,z\),求 \(∑_{i=l}^{r}dep[LCA(i,z)]\) 。
(即,求在 \([l,r]\) 区间内的每个节点i与z的最近公共祖先的深度之和)
输入格式
第一行 \(2\) 个整数 \(n,q\)。
接下来 \(n-1\) 行,分别表示点 \(1\) 到点 \(n-1\) 的父节点编号。
接下来 \(q\) 行,每行 \(3\) 个整数 \(l,r,z\)。
输出格式
输出 \(q\) 行,每行表示一个询问的答案。每个答案对 \(201314\) 取模输出
样例
样例输入
5 2
0
0
1
1
1 4 3
1 4 2
样例输出
8
5
数据范围与提示
共 \(10\) 组数据,\(n\) 与 \(q\) 的规模分别为 \(1000,2000,3000,4000,5000,10000,20000,30000,40000,50000\)。
思路分析
- 题目给了深度的定义,其实有一定的提示,所以我们就直接对相应结点到根节点的路径进行处理
- 那么这时 \(LCA\) 的深度就变成了两个点到根节点的路径的重合部分。
- 所以题意就变成了:每次把询问区间 \([l,r]\) 里的点到根节点路径上的点权值加一,最后询问 \(z\) 到根节点的路径上的权值和。这个操作可以用线段树实现
- 由于我们每一次修改和查询都是和根节点有关,所以我们把 \([l,r]\) 通过差分拆成 \([1,l-1]\) 和 \([1,r]\) 最后查询的时候用 \([1,r]\) - \([1,l-1]\) 即可
- 最后为了答案不会互相干扰,所以可以使用离线做法,将每个 \([1,x]\) 区间按 \(x\) 的大小排序即可,另外标记一下是 \(l\) 还是 \(r\) 即可
\(Code\)
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<queue>
#define R register
#define N 200010
#define int long long
using namespace std;
inline int read(){
int x=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
const int mod = 201314;
int n,q,dep[N],f[N],siz[N],son[N],dfn[N],cnt,top[N],head[N],ans[N];
int tr[N<<2],tag[N<<2];
struct edge{
int to,next;
}e[N<<1];
struct problem{
int pos,z,id,flag;
problem(){}
problem(int _pos,int _z,int _id,int _flag){pos=_pos,z=_z,id=_id,flag=_flag;}
bool operator <(const problem &a)const{
return pos < a.pos;
}
}ask[N<<1];
int len;
void addedge(int u,int v){
e[++len].to = v;
e[len].next = head[u];
head[u] = len;
}
void dfs(int u,int fa){
dep[u] = dep[fa]+1;
f[u] = fa;
siz[u] = 1;
for(int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v==fa)continue;
dfs(v,u);
siz[u] += siz[v];
if(siz[v]>siz[son[u]])son[u] = v;
}
}
void dfs2(int u,int tp){
dfn[u] = ++cnt;
top[u] = tp;
if(son[u])dfs2(son[u],tp);
for(int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v==f[u]||v==son[u])continue;
dfs2(v,v);
}
}
#define ls rt<<1
#define rs rt<<1|1
inline void pushdown(int rt,int l,int r){
if(tag[rt]){
tag[ls] += tag[rt];
tag[rs] += tag[rt];
int mid = (l+r)>>1;
tr[ls] += (mid-l+1)*tag[rt];
tr[rs] += (r-mid)*tag[rt];
tag[rt] = 0;
}
}
void update(int rt,int l,int r,int s,int t){
if(s<=l&&t>=r){
tag[rt] += 1;
tr[rt] += r-l+1;
return;
}
pushdown(rt,l,r);
int mid = (l+r)>>1;
if(s<=mid)update(ls,l,mid,s,t);
if(t>mid)update(rs,mid+1,r,s,t);
tr[rt] = tr[ls]+tr[rs];
}
int query(int rt,int l,int r,int s,int t){
if(s<=l&&t>=r)return tr[rt];
pushdown(rt,l,r);
int mid = (l+r)>>1;
if(t<=mid)return query(ls,l,mid,s,t);
else if(s>mid)return query(rs,mid+1,r,s,t);
else return query(ls,l,mid,s,t)+query(rs,mid+1,r,s,t);
}
void modify(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
update(1,1,n,dfn[top[u]],dfn[u]);
u = f[top[u]];
}
if(dep[u]>dep[v])swap(u,v);
update(1,1,n,dfn[u],dfn[v]);
}
int getsum(int u,int v){
int res = 0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]])swap(u,v);
res = (res+query(1,1,n,dfn[top[u]],dfn[u])+mod)%mod;
u = f[top[u]];
}
if(dep[u]>dep[v])swap(u,v);
res = (res+query(1,1,n,dfn[u],dfn[v]))%mod;
return res;
}
signed main(){
freopen("C.in","r",stdin);
freopen("C.out","w",stdout);
n = read(),q = read();
f[1] = 1;
for(int i = 2;i <= n;i++){
int u = read();u++;
addedge(u,i),addedge(i,u);
}
dfs(1,0);
dfs2(1,0);
int tot = 0;
for(int i = 1;i <= q;i++){
int l,r,z;l = read(),r = read(),z = read();
l++,r++,z++;
ask[tot++] = problem(l-1,z,i,0);
ask[tot++] = problem(r,z,i,1);
}
sort(ask,ask+tot);
int cur = 1;
for(int i = 0;i < tot;i++){
while(cur<=ask[i].pos)modify(1,cur++); //经过的路径进行赋值
if(ask[i].flag)ans[ask[i].id] += getsum(1,ask[i].z); //右端点加
else ans[ask[i].id] -= getsum(1,ask[i].z); //左端点减
ans[ask[i].id] = (ans[ask[i].id]+mod)%mod;
}
for(int i = 1;i <= q;i++){
printf("%lld\n",ans[i]);
}
return 0;
}