【洛谷P4211】LCA
题目
题目链接:https://www.luogu.com.cn/problem/P4211
给出一个 \(n\) 个节点的有根树(编号为 \(0\) 到 \(n-1\),根节点为 \(0\))。
一个点的深度定义为这个节点到根的距离 \(+1\)。
设 \(dep[i]\) 表示点i的深度,\(LCA(i,j)\) 表示 \(i\) 与 \(j\) 的最近公共祖先。
有 \(q\) 次询问,每次询问给出 \(l\ r\ z\),求 \(\sum_{i=l}^r dep[LCA(i,z)]\) 。
思路
首先 \(dep[lca(x,y)]\) 等价于把 \(x\) 的所有祖先节点标记为 1,然后求 \(y\) 的祖先节点的权值和。
那么 \(\sum^{r}_{i=l} dep[lca(x,i)]\) 等价于把 \([l,r]\) 的所有点的祖先节点全部加一,求 \(x\) 的祖先节点的权值和。
把询问拆成 \(1\sim r\) 的和减去 \(1\sim l-1\) 的和,然后按照编号从小到大枚举点,树剖 + 线段树将这个点到 1 的路径的点权值全部加一。
对于询问就直接求到 1 的权值和即可。
时间复杂度 \(O(n\log^2 n)\)。
代码
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=50010,MOD=201314;
int head[N],son[N],fa[N],size[N],id[N],rk[N],top[N];
int n,Q,tot;
vector<int> pos[N];
struct edge
{
int next,to;
}e[N];
struct Query
{
int x,l,r,ans;
}ask[N];
void add(int from,int to)
{
e[++tot].to=to;
e[tot].next=head[from];
head[from]=tot;
}
void dfs1(int x)
{
size[x]++;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
dfs1(v);
size[x]+=size[v];
if (size[v]>size[son[x]]) son[x]=v;
}
}
void dfs2(int x,int tp)
{
id[x]=++tot; rk[tot]=x; top[x]=tp;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=son[x]) dfs2(v,v);
}
}
struct SegTree
{
int l[N*4],r[N*4],len[N*4],sum[N*4],lazy[N*4];
void pushdown(int x)
{
if (lazy[x])
{
sum[x*2]=(sum[x*2]+lazy[x]*len[x*2])%MOD;
sum[x*2+1]=(sum[x*2+1]+lazy[x]*len[x*2+1])%MOD;
lazy[x*2]=(lazy[x*2]+lazy[x])%MOD;
lazy[x*2+1]=(lazy[x*2+1]+lazy[x])%MOD;
lazy[x]=0;
}
}
void build(int x,int ql,int qr)
{
l[x]=ql; r[x]=qr; len[x]=qr-ql+1;
if (ql==qr) return;
int mid=(ql+qr)>>1;
build(x*2,ql,mid);
build(x*2+1,mid+1,qr);
}
void pushup(int x)
{
sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
}
void update(int x,int ql,int qr)
{
if (l[x]==ql && r[x]==qr)
{
sum[x]=(sum[x]+len[x])%MOD; lazy[x]++;
return;
}
pushdown(x);
int mid=(l[x]+r[x])>>1;
if (qr<=mid) update(x*2,ql,qr);
else if (ql>mid) update(x*2+1,ql,qr);
else update(x*2,ql,mid),update(x*2+1,mid+1,qr);
pushup(x);
}
int query(int x,int ql,int qr)
{
if (l[x]==ql && r[x]==qr)
return sum[x];
pushdown(x);
int mid=(l[x]+r[x])>>1;
if (qr<=mid) return query(x*2,ql,qr);
if (ql>mid) return query(x*2+1,ql,qr);
return query(x*2,ql,mid)+query(x*2+1,mid+1,qr);
}
}seg;
void Update(int x)
{
while (x)
{
seg.update(1,id[top[x]],id[x]);
x=fa[top[x]];
}
}
int Query(int x)
{
int ans=0;
while (x)
{
ans=(ans+seg.query(1,id[top[x]],id[x]))%MOD;
x=fa[top[x]];
}
return ans;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&Q);
for (int i=2,x;i<=n;i++)
{
scanf("%d",&x);
add(x+1,i); fa[i]=x+1;
}
for (int i=1;i<=Q;i++)
{
scanf("%d%d%d",&ask[i].l,&ask[i].r,&ask[i].x);
ask[i].r++; ask[i].x++;
pos[ask[i].l].push_back(i);
pos[ask[i].r].push_back(i);
}
tot=0;
dfs1(1); dfs2(1,1);
seg.build(1,1,n);
for (int i=0;i<=n;i++)
{
Update(i);
for (int j=0;j<pos[i].size();j++)
{
int k=pos[i][j],s=Query(ask[k].x);
if (ask[k].l==i) ask[k].ans-=s;
else ask[k].ans+=s;
}
}
for (int i=1;i<=Q;i++)
printf("%d\n",(ask[i].ans%MOD+MOD)%MOD);
return 0;
}