bzoj 3626 : [LNOI2014]LCA (树链剖分+线段树)

Description

给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)

 

Input

第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。

 

Output

输出q行,每行表示一个询问的答案。每个答案对201314取模输出

 

Sample Input

5 2
0
0
1
1
1 4 3
1 4 2

Sample Output

8
5
 
思路:
表述不清。。。直接贴大佬的题解思路把

清华爷gconeice的题解:

显然,暴力求解的复杂度是无法承受的。
考虑这样的一种暴力,我们把 z 到根上的点全部打标记,对于 l 到 r 之间的点,向上搜索到第一个有标记的点求出它的深度统计答案。观察到,深度其实就是上面有几个已标记了的点(包括自身)。所以,我们不妨把 z 到根的路径上的点全部 +1,对于 l 到 r 之间的点询问他们到根路径上的点权和。仔细观察上面的暴力不难发现,实际上这个操作具有叠加性,且可逆。也就是说我们可以对于 l 到 r 之间的点 i,将 i 到根的路径上的点全部 +1, 转而询问 z 到根的路径上的点(包括自身)的权值和就是这个询问的答案。把询问差分下,也就是用 [1, r] − [1, l − 1] 来计算答案,那么现在我们就有一个明显的解法。从 0 到 n − 1 依次插入点 i,即将 i 到根的路径上的点全部+1。离线询问答案即可。我们现在需要一个数据结构来维护路径加和路径求和,显然树链剖分或LCT 均可以完成这个任务。树链剖分的复杂度为 O((n + q)· log n · log n),LCT的复杂度为 O((n + q)· log n),均可以完成任务。至此,题目已经被我们完美解决。

 

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mid int m = (l + r) >> 1
const int mod = 201314;
const int M = 1e5 + 10;
struct node{
    int to,next;
}e[M];
int cnt,cnt1,n;
int son[M],siz[M],head[M],fa[M],top[M],dep[M],tid[M];
int sum[M<<2],lazy[M<<2];
void add(int u,int v){
    e[++cnt].to = v;e[cnt].next = head[u];head[u] = cnt;
}

void dfs1(int u,int faz,int deep){
    dep[u] = deep;
    fa[u] = faz;
    siz[u] = 1;
    for(int i = head[u];i;i = e[i].next){
        int v = e[i].to;
        if(v == faz) continue;
        dfs1(v,u,deep+1);
        siz[u] += siz[v];
        if(siz[v] > siz[son[u]]||son[u] == -1)
            son[u] = v;
    }
}

void dfs2(int u,int t){
    top[u] = t;
    tid[u] = ++cnt1;
    if(son[u] == -1) return ;
    dfs2(son[u],t);
    for(int i = head[u];i;i=e[i].next){
        int v = e[i].to;
        if(v != fa[u]&&v != son[u])
            dfs2(v,v);
    }
}

void pushup(int rt){
    sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}

void pushdown(int l,int r,int rt){
    if(lazy[rt]){
        mid;
        sum[rt<<1] += (m-l+1)*lazy[rt];
        sum[rt<<1|1] += (r-m)*lazy[rt];
        lazy[rt<<1] += lazy[rt];
        lazy[rt<<1|1] += lazy[rt];
        lazy[rt] = 0;
    }
}

void build(int l,int r,int rt){
    lazy[rt] = 0;
    if(l==r){
        sum[rt] = lazy[rt] = 0;
        return ;
    }
    mid;
    build(lson); build(rson);
    pushup(rt);
}

void update(int L,int R,int l,int r,int rt){
    if(L <= l&&R >= r){
         sum[rt] += (r-l+1);
         lazy[rt] += 1;
         return ;
    }
    pushdown(l,r,rt);
    mid;
    if(L <= m) update(L,R,lson);
    if(R > m) update(L,R,rson);
    pushup(rt);
}

int query(int L,int R,int l,int r,int rt){
     if(L <= l&&R >= r){
        return sum[rt];
     }
     pushdown(l,r,rt);
     mid;
     int ret = 0;
     if(L <= m) ret += query(L,R,lson);
     if(R > m) ret += query(L,R,rson);
     return ret;
}

void update(int x,int y){
    int fx = top[x],fy = top[y];
    while(fx != fy){
        if(dep[fx] < dep[fy]) swap(x,y),swap(fx,fy);
        update(tid[fx],tid[x],1,n,1);
        x = fa[fx]; fx = top[x];
    }
    if(dep[x] > dep[y]) swap(x,y);
    update(tid[x],tid[y],1,n,1);
}

int solve(int x,int y){
    int fx = top[x],fy = top[y];
    int ans = 0;
    while(fx != fy){
        if(dep[fx] < dep[fy]) swap(x,y),swap(fx,fy);
        ans += query(tid[fx],tid[x],1,n,1);
        x = fa[fx]; fx = top[x];
    }
    if(dep[x] > dep[y]) swap(x,y);
    ans += query(tid[x],tid[y],1,n,1);
    return ans;
}
struct node1{
    int ans1,ans2,z;
}q[M];
struct node2{
    int pos,flag,id;
}a[M];
bool cmp(node2 x,node2 y){
    return x.pos < y.pos;
}

int main()
{
    int m,l,r,x;
    cnt = 0;cnt1 = 0;
    scanf("%d%d",&n,&m);
    for(int i = 1;i < n;i ++){
        scanf("%d",&x);
        add(x+1,i+1);
    }
    memset(son,-1,sizeof(son));
    dfs1(1,0,1); dfs2(1,0);
    build(1,n,1);
    int tot = 0;
    for(int i = 1;i <= m;i ++){
        scanf("%d%d%d",&l,&r,&q[i].z);
        l++;r++;q[i].z++;
        a[++tot].pos=l-1;a[tot].flag=0;a[tot].id=i;
        a[++tot].pos=r;a[tot].flag=1;a[tot].id=i;
    }
    int now = 1;
    sort(a+1,a+tot+1,cmp);
    for(int i = 1;i <= tot;i ++){
        while(now <= a[i].pos){
            update(1,now);
            now++;
        }
        int num = a[i].id;
        if(a[i].flag==0) q[num].ans1 = solve(1,q[num].z);
        else q[num].ans2 = solve(1,q[num].z);
    }
    for(int i = 1;i <= m;i ++){
        int ans = (q[i].ans2-q[i].ans1)%mod;
        printf("%d\n",ans);
    }
    return 0;
}

 

posted @ 2018-11-01 18:57  冥想选手  阅读(143)  评论(0编辑  收藏  举报