BZOJ3626[LNOI2014]LCA——树链剖分+线段树

题目描述

给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)

输入

第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。

输出

输出q行,每行表示一个询问的答案。每个答案对201314取模输出

样例输入

5 2
0
0
1
1
1 4 3
1 4 2

样例输出

8
5

提示

共5组数据,n与q的规模分别为10000,20000,30000,40000,50000。

  

  两个点a,b的lca的深度就是dep[lca],如果暴力地写这道题就是对于每个x与[l,r]内所有数的lca都求一遍,但可以发现lca还有一种求法:对于i,x两点的lca,可以把i到根节点路径上所有的边权+1(刚开始都是零),只要再求x到根节点上的路径和就是lca的深度。那么对于[l,r]内所有的点和x的lca,只要把每个点到根的路径上边权都+1,然后再求x到根的路径和就好了。这个只要树链剖分加线段树就能维护,每次修改和查询在树上边跳边在线段树中操作就行了。但对于每次询问都要把线段树清空再重新标记,显然还是不行的,因此可以离线来做。我们发现求的东西具有可减性,即求[l,r]与x的lca深度和等于求[1,r]与x的lca深度和-[1,l-1]与x的lca深度和。因此每个询问可以拆成两部分,然后把所有查询排序,按节点标号顺序对到根路径上的边+1,每到一个点处理这个点处对应的查询。注意点的编号从零开始。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int x;
int l,r;
int n,m;
int tot;
int num;
int cnt;
int f[100010];
int d[100010];
int s[100010];
bool g[100010];
int a[1000010];
int to[100010];
ll sum[800010];
ll ans[100010];
int top[100010];
int son[100010];
int size[100010];
int head[100010];
int next[100010];
struct node
{
    int x;
    int l;
    int id;
}q[200010];
bool cmp(node a,node b)
{
    return a.l<b.l;
}
void add(int x,int y)
{
    tot++;
    next[tot]=head[x];
    head[x]=tot;
    to[tot]=y;
}
void dfs(int x)
{
    size[x]=1;
    for(int i=head[x];i;i=next[i])
    {
        d[to[i]]=d[x]+1;
        f[to[i]]=x;
        dfs(to[i]);
        size[x]+=size[to[i]];
        if(size[to[i]]>size[son[x]])
        {
            son[x]=to[i];
        }
    }
}
void dfs2(int x,int tp)
{
    s[x]=++num;
    top[x]=tp;
    if(son[x])
    {
        dfs2(son[x],tp);
    }
    for(int i=head[x];i;i=next[i])
    {
        if(to[i]!=son[x])
        {
            dfs2(to[i],to[i]);
        }
    }
}
void pushup(int rt)
{
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
void pushdown(int rt,int l,int r)
{
    if(a[rt])
    {
        int mid=(l+r)>>1;
        a[rt<<1]+=a[rt];
        a[rt<<1|1]+=a[rt];
        sum[rt<<1]+=1ll*a[rt]*(mid-l+1);
        sum[rt<<1|1]+=1ll*a[rt]*(r-mid);
        a[rt]=0;
    }
}
void change(int rt,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        a[rt]++;
        sum[rt]+=1ll*(r-l+1);
        return ;
    }
    pushdown(rt,l,r);
    int mid=(l+r)>>1;
    if(L<=mid)
    {
        change(rt<<1,l,mid,L,R);
    }
    if(R>mid)
    {
        change(rt<<1|1,mid+1,r,L,R);
    }
    pushup(rt);
}
ll query(int rt,int l,int r,int L,int R)
{
    if(L<=l&&r<=R)
    {
        return sum[rt];
    }
    pushdown(rt,l,r);
    int mid=(l+r)>>1;
    ll res=0;
    if(L<=mid)
    {
        res+=query(rt<<1,l,mid,L,R);
    }
    if(R>mid)
    {
        res+=query(rt<<1|1,mid+1,r,L,R);
    }
    return res;
}
void updata(int x)
{
    while(top[x]!=1)
    {
        change(1,1,n,s[top[x]],s[x]);
        x=f[top[x]];
    }
    change(1,1,n,1,s[x]);
}
ll downdata(int x)
{
    ll res=0;
    while(top[x]!=1)
    {
        res+=query(1,1,n,s[top[x]],s[x]);
        x=f[top[x]];
    }
    res+=query(1,1,n,1,s[x]);
    return res;
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        scanf("%d",&x);
        add(x+1,i+1);
    }
    dfs(1);
    dfs2(1,1);
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&l,&r,&x);
        x++;
        l++;
        r++;
        q[++cnt].l=l-1;
        q[cnt].x=x;
        q[cnt].id=i;
        q[++cnt].l=r;
        q[cnt].x=x;
        q[cnt].id=i;
    }
    sort(q+1,q+1+cnt,cmp);
    int now=1;
    for(int i=0;i<=n;i++)
    {
        if(i!=0)
        {
            updata(i);
        }
        while(q[now].l==i&&now<=cnt)
        {
            if(g[q[now].id]==0)
            {
                ans[q[now].id]-=downdata(q[now].x);
                g[q[now].id]=1;
            }
            else
            {
                ans[q[now].id]+=downdata(q[now].x);
            }
            now++;
        }
    }
    for(int i=1;i<=m;i++)
    {
        printf("%lld\n",ans[i]%201314);
    }
}
posted @ 2018-08-31 22:58  The_Virtuoso  阅读(611)  评论(0编辑  收藏  举报