BZOJ3626LCA(树剖+线段树+LCA+差分)

题面

给定一棵树,有q次询问,每次询问都是l,r,p,每次询问i=lrdeep[LCA(i,p)].

数据范围达到了上万,显然每次询问要么log地询问,要么O(1)询问,log级的显然是比较难处理的,于是我们想到把询问离线,发现这东西满足差分的性质,即处理出[1,l1][1,r]然后两者相减就是答案.想到这个之后我们考虑怎么求对于每个询问的答案.我们这么想,我们从1用一个指针往后扫,发现遇见这个点是询问中的点(l-1 or r)然后我们对这个询问的p处理下,算出答案.关键是这个答案要怎么算?这里要用到一个非常巧妙的东西,我们从1往后扫的同时,把这个点到根节点上所有点都打上+1的标记,然后我们有一个p,实际上答案就是求p到根路径上标记之和,在纸上画一下就好了.这样每次(log2n)2的复杂度,总复杂度为O(n(log2n)2).

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<algorithm>

using namespace std;

typedef long long LL;
#define REP(i,a,b) for(register int i = (a),i##_end_ = (b); i <= i##_end_; ++i)
#define DREP(i,a,b) for(register int i = (a),i##_end_ = (b);i >= i##_end_; --i)
#define mem(a,b) memset(a,b,sizeof(a));

int read()
{
    register int flag = 1,sum = 0;char c = getchar();
    while(!isdigit(c)) { if(c == '-')flag = -1; c = getchar(); }
    while(isdigit(c)) { sum = sum * 10 + c - '0'; c = getchar(); }
    return sum * flag;
}

const int maxn = 400000+10;
const int mod = 201314;
int n,m;
int be[maxn],ne[maxn],to[maxn],e;

void add(int x,int y)
{
    to[++e] = y;ne[e] = be[x];be[x] = e;
}

int dep[maxn],size[maxn],son[maxn],fa[maxn];

void dfs1(int x,int dp,int faa)
{
    dep[x] = dp;size[x] = 1;
    for(int i = be[x]; i; i = ne[i])
    {
        int v = to[i];
        if(v != faa)
        {
            dfs1(v,dp+1,x);fa[v] = x;
            size[x] += size[v];
            if(!son[x] || size[son[x]] < size[v])son[x] = v;
        }
    }
}

int top[maxn],tid[maxn],cnt,vis[maxn];

void dfs2(int x,int tp)
{
    vis[x] = 1;top[x] = tp;tid[x] = ++cnt;
    if(!son[x])return ;
    dfs2(son[x],tp);
    for(int i = be[x]; i; i = ne[i])
    {
        int v = to[i];
        if(!vis[v])dfs2(v,v);
    }
}


struct T
{
    int id,flag,z,p;
    LL ans;
    bool operator < (const T&u)const
    {
        return z < u.z;
    }
}q[maxn];
int num;

struct node
{
    int l,r,ld,rd,p;
}ask[maxn];

LL tr[maxn<<2],tag[maxn<<2];

void pushdown(int h,int l,int r)
{
    int mid = (l + r) >> 1;
    tag[h<<1] += tag[h];
    tag[h<<1|1] += tag[h];
    (tr[h<<1] += (mid - l + 1) * tag[h])%=mod;
    (tr[h<<1|1] += (r - mid) * tag[h])%=mod;
    tag[h] = 0;
}

void updata(int h,int l,int r,int q,int w)
{
    if(q <= l &&r <= w)
    {
        (tr[h] += w - q + 1)%=mod;
        tag[h]++;return ;
    }
    int mid = (l + r) >> 1;
    if(tag[h])pushdown(h,l,r);
    if(w <= mid)updata(h<<1,l,mid,q,w);
    else if(q > mid)updata(h<<1|1,mid+1,r,q,w);
    else updata(h<<1,l,mid,q,w),updata(h<<1|1,mid+1,r,q,w);
    tr[h] = tr[h<<1] + tr[h<<1|1];
}

LL query(int h,int l,int r,int q,int w)
{
    if(q <= l && r <= w)return tr[h];
    if(tag[h])pushdown(h,l,r);
    int mid = (l + r) >> 1;
    if(w <= mid)return query(h<<1,l,mid,q,w);
    else if(q > mid)return query(h<<1|1,mid+1,r,q,w);
    else return query(h<<1,l,mid,q,w)+query(h<<1|1,mid+1,r,q,w);
}

void add(int x)
{
    int y = 0; 
    while(top[x] != top[y])
    {
        if(dep[top[x]] < dep[top[y]])swap(x,y);
        updata(1,1,cnt,tid[top[x]],tid[x]);
        x = fa[top[x]];
    }
    updata(1,1,cnt,tid[0],tid[x]);
}

void solve(int id)
{
    int x = q[id].p,y = 0;
    LL res = 0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]] < dep[top[y]])swap(x,y);
        (res += query(1,1,cnt,tid[top[x]],tid[x]))%=mod;
        x = fa[top[x]];
    }
    (res += query(1,1,cnt,tid[0],tid[x]))%=mod;
    q[id].ans = res;
}

int main()
{
    freopen("1.in","r",stdin);
    freopen("1.out","w",stdout);
    n = read();m = read();
    REP(i,1,n-1)
    {
        int x = read();
        add(x,i);add(i,x);
    }//slpf
    dfs1(0,0,-1);dfs2(0,0);
    //sort,prehandle query
    REP(i,1,m)
    {
        ask[i].l = read(),ask[i].r = read(),ask[i].p = read();
        q[++num] = (T){i,0,ask[i].l-1,ask[i].p};
        q[++num] = (T){i,1,ask[i].r,ask[i].p};
    }
    sort(q+1,q+1+num);  
    REP(i,1,num)
    {
        if(!q[i].flag)ask[q[i].id].ld = i;
        else ask[q[i].id].rd = i;
    }
    int ths = 1,pt = -1;
    for(;pt < n && ths <= num;++ths)
    {
        while(pt < q[ths].z)
            ++pt,add(pt);
        solve(ths);
    }
    REP(i,1,m)
        cout<<(q[ask[i].rd].ans-q[ask[i].ld].ans+mod)%mod<<endl;
}
posted @ 2017-09-08 18:47  Drinkwater_cnyali  阅读(166)  评论(0编辑  收藏  举报