BZOJ 3626 [LNOI2014]LCA

传送门

一道比较神奇的题。

树链剖分+奇技淫巧;

神奇地发现,把z到跟的路径上的点值+1,查询一个点到跟的路径和就是它与z的lca的深度。

相对的,把l~r到跟的路径上的点值+1,查询z到跟的路径和就是要的答案。

考虑差分,把一个询问拆成两个,把所有询问排序然后从0~n-1到跟路径上的值+1;

一开始狂WA,发现把线段树区间加的(l-r)*v打成了(qr-ql)*v了。。。

//Twenty
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<queue>
#include<vector>
typedef long long LL;
using namespace std;
const int mod=201314;
const int N=200000+299;
int n,m,cnt,l,r,z;
LL ans[N];
struct qs{
    int id,x,f,ts,z;
    qs(){}
    qs(int id,int x,int f,int ts,int z):id(id),x(x),f(f),ts(ts),z(z){}
    friend bool operator <(const qs&a,const qs&b) {
        return a.x<b.x;
    }
}q[N]; 

int ecnt,fir[N],nxt[N],to[N];

void add(int u,int v) {
    nxt[++ecnt]=fir[u]; fir[u]=ecnt; to[ecnt]=v;
}

#define lc x<<1
#define rc x<<1|1
#define mid ((l+r)>>1)
LL sg[N<<2],lz[N<<2];

void down(int x,int l_len,int r_len) {
     if(!lz[x]) return;
     if(lc) (sg[lc]+=l_len*lz[x])%=mod,(lz[lc]+=lz[x])%=mod;
     if(rc) (sg[rc]+=r_len*lz[x])%=mod,(lz[rc]+=lz[x])%=mod;
     lz[x]=0;
}

void add(int x,int l,int r,int ql,int qr,LL v) {
    if(l>=ql&&r<=qr) { (sg[x]+=v*(r-l+1))%=mod; (lz[x]+=v)%=mod; return;}
    down(x,mid-l+1,r-mid);
    if(ql<=mid) add(lc,l,mid,ql,qr,v);
    if(qr>mid) add(rc,mid+1,r,ql,qr,v);
    sg[x]=sg[lc]+sg[rc];
    if(sg[x]>=mod) sg[x]-=mod;
}

LL query(int x,int l,int r,int ql,int qr) {
    if(l>=ql&&r<=qr) return sg[x];
    down(x,mid-l+1,r-mid);
    if(qr<=mid) return query(lc,l,mid,ql,qr);
    if(ql>mid) return query(rc,mid+1,r,ql,qr);
    LL res=query(lc,l,mid,ql,qr)+query(rc,mid+1,r,ql,qr);
    return res>=mod?res-mod:res;
}

int R[N],sz[N];

void DFS(int x) {
    sz[x]=1;
    for(int i=fir[x];i;i=nxt[i]) {
        R[to[i]]=R[x]+1;
        DFS(to[i]);
        sz[x]+=sz[to[i]];
    }
}

int tot,top[N],fa[N],tid[N];

void dfs(int x,int t) {
    top[x]=t;
    tid[x]=++tot;
    int mson=0;
    for(int i=fir[x];i;i=nxt[i])
        if(!mson||sz[to[i]]>=sz[mson]) mson=to[i];
    if(!mson) return;
    dfs(mson,t);
    for(int i=fir[x];i;i=nxt[i])
    if(to[i]!=mson) dfs(to[i],to[i]);
}

void schange(int l,int r,int v) {
    while(top[l]!=top[r]) {
        if(R[top[l]]<R[top[r]]) swap(l,r);
        add(1,1,n,tid[top[l]],tid[l],v);
        l=fa[top[l]];
    }
    if(tid[l]>tid[r]) swap(l,r);
    add(1,1,n,tid[l],tid[r],v);
}

LL squery(int l,int r) {
    LL res=0;
    while(top[l]!=top[r]) {
        if(R[top[l]]<R[top[r]]) swap(l,r);
        res+=query(1,1,n,tid[top[l]],tid[l]);
        if(res>=mod) res-=mod;
        l=fa[top[l]];
    }
    if(tid[l]>tid[r]) swap(l,r);
    res+=query(1,1,n,tid[l],tid[r]);
    if(res>=mod) res-=mod;
    return res;
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++) {
        scanf("%d",&fa[i]);
        add(fa[i],i);
    }
    DFS(0);
    dfs(0,0);
    for(int i=1;i<=m;i++) {
        scanf("%d%d%d",&l,&r,&z);
        if(l) q[++cnt]=qs(i,l-1,-1,0,z);
        q[++cnt]=qs(i,r,1,0,z);
    }
    sort(q+1,q+cnt+1);
    int now=1;
    for(int i=0;i<n;i++) {
        schange(0,i,1);
        while(now<=m*2&&q[now].x<=i) {
            LL res=squery(0,q[now].z);
            ans[q[now].id]+=q[now].f*res;
            now++;
            if(now>2*m) break;
        }
    }
    for(int i=1;i<=m;i++)
        printf("%lld\n",(ans[i]+mod)%mod);
    return 0;
}
View Code

 

posted @ 2017-09-29 08:03  啊宸  阅读(169)  评论(0编辑  收藏  举报