P4211 [LNOI2014]LCA
分析
本题要计算的就是l~r与z的LCA的深度之和
我们来看看,是否可以将求多个dep转化一下
我们先对dep有一个理解,dep就是从i到root总共有多少点
我们从整体上考虑,发现对于一个询问:l , r , z 来说,所有的 lca 都在 z 到根的路径上。从而有一些点,它们对很多的 lca 的深度都有贡献,而这个贡献等于在这个点下面的 lca 的个数,所以我们可以把每个 lca 到根的路径上的每个点的权值都加一。然后从 z 向上走到根,沿路统计的权值就是答案了。
总结一下,若只有一次询问,则我们想统计答案,只需要将l~r中所有点,从其位置向根节点中的路径的所有点点权+1,最后从z向根节点求一个区间和,即为答案
是时候考虑优化的问题了:我们每次的操作都是从某个点到根的,所以树链剖分+线段树就好了。
但是考虑到每次统计时,不能很好的排除 l ~ r 区间之外的点对 z->根 这条路径的贡献,所以我们每次都要清空线段树。
我们每次清空线段树,然后从 l ~ r 再添加一遍,树剖+线段树的复杂度就是\(O(n * logn * logn)\)的,还要做 q 次,复杂度依然不理想。
看数据范围,\(O(n * logn * logn)\)应该就是正解了,现在要想办法优化掉最后的那个 q 的复杂度。
我们看到区间 l~r ,我们需要考虑的是,如何排除l~r之外区间的影响?
我们联想一下类似于主席树的思路,我们从1开始以此将从该节点到根节点的路径中的所有点全部+1,则对于[l,r]的区间影响,我们可以通过用[1,r]的版本减去[1,l-1]版本的影响
我们可以将询问的区间拆开为两个,将询问离线查询。按照右端点从小到大排序(左端点都是根),然后按从小到大的顺序添加点,每遇到一个询问就查询一次,从而排除掉区间之外的点的影响,也就优化掉一个 q 的复杂度。
具体的实现过程可以看一看代码。
Ac_code
#include<bits/stdc++.h>
using namespace std;
const int N = 5e4 + 10,mod = 201314;
typedef pair<int,int> PII;
typedef long long LL;
struct Node
{
int l,r,sum,tag;
}tr[N<<2];
struct Query
{
int r,z,id;//每一个询问拆成两部分,分别存储一个询问的l-1,r,以及z,询问的编号,以及是左端点还是右端点
bool f;
bool operator<(const Query& W)const
{
return r<W.r;
}
}que[N<<1];
int h[N],e[N],ne[N],idx;
int sz[N],son[N],fa[N],dep[N];
int top[N],id[N],ts;
PII ans[N];
//对于每一个ans的first存储的是对于1~l-1的版本中,从z到根的权值和。
//而每一个ans的second存储的是对于1~r的版本中,从z到根的权值和
//两者相减,则可以排除掉其余区间对z到根节点的权值和的影响。
int n,q;
void add(int a,int b)
{
e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}
void dfs1(int u,int depth)
{
sz[u] = 1,dep[u] = depth++;
for(int i=h[u];~i;i=ne[i])
{
int j = e[i];
dfs1(j,depth+1);
sz[u] += sz[j];
if(sz[j]>sz[son[u]]) son[u] = j;
}
}
void dfs2(int u,int tp)
{
top[u] = tp,id[u] = ++ts;
if(!son[u]) return ;
dfs2(son[u],tp);
for(int i=h[u];~i;i=ne[i])
{
int j = e[i];
if(j==son[u]) continue;
dfs2(j,j);
}
}
void pushup(int u)
{
tr[u].sum = (LL)(tr[u<<1].sum + tr[u<<1|1].sum)%mod;
}
void pushdown(int u)
{
auto &root = tr[u],&left = tr[u<<1],&right = tr[u<<1|1];
if(root.tag)
{
left.tag = (LL)(left.tag + root.tag)%mod;
right.tag = (LL)(right.tag + root.tag)%mod;
left.sum = (LL)(left.sum + 1ll*(left.r - left.l + 1)*root.tag%mod)%mod;
right.sum = (LL)(right.sum + 1ll*(right.r - right.l + 1)*root.tag%mod)%mod;
root.tag = 0;
}
}
void build(int u,int l,int r)
{
tr[u] = {l,r};
if(l==r) return ;
int mid = l + r >> 1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}
void modify(int u,int l,int r)
{
if(l<=tr[u].l&&tr[u].r<=r)
{
tr[u].tag ++ ;
tr[u].sum = (LL)(tr[u].sum + tr[u].r - tr[u].l + 1)%mod;
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l<=mid) modify(u<<1,l,r);
if(r>mid) modify(u<<1|1,l,r);
pushup(u);
}
int query(int u,int l,int r)
{
if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;
int res = 0;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l<=mid) res = (LL)(res + query(u<<1,l,r))%mod;
if(r>mid) res = (LL)(res + query(u<<1|1,l,r))%mod;
return res;
}
void modify_chain(int x)
{
while(top[x]!=1)
{
modify(1,id[top[x]],id[x]);
x = fa[top[x]];
}
modify(1,1,id[x]);
}
int query_chain(int x)
{
int res = 0;
while(top[x]!=1)
{
res = (LL)(res + query(1,id[top[x]],id[x]))%mod;
x = fa[top[x]];
}
res = (LL)(res + query(1,1,id[x]))%mod;
return res;
}
int main()
{
scanf("%d%d",&n,&q);
memset(h,-1,sizeof h);
for(int i=2;i<=n;i++)
{
cin>>fa[i];fa[i]++;
add(fa[i],i);
}
for(int i=0;i<q;i++)
{
int l,r,z;scanf("%d%d%d",&l,&r,&z);
l++,r++,z++;
que[i*2] = {l-1,z,i,1};
que[i*2+1] = {r,z,i,0};
}
dfs1(1,1);
dfs2(1,1);
build(1,1,n);
sort(que,que+q*2);
int now = 0;
for(int i=0;i<q*2;i++)
{
while(now<que[i].r) modify_chain(++now);//对于每一个点都将其到根的路径中的所有点点权+1
if(que[i].f) ans[que[i].id].first = query_chain(que[i].z);//遇到询问后,记录一下这是哪个询问的左端点版本还是右端点版本
else ans[que[i].id].second = query_chain(que[i].z);
}
for(int i=0;i<q;i++)
printf("%d\n",((ans[i].second-ans[i].first)%mod+mod)%mod);
return 0;
}