bzoj 2588 树上主席树

主席树上树,对于每个节点,继承其父亲的,最后跑f[x]+f[y]-f[lca]-f[fa[lca]]

去重竟然要减一,我竟然不知道??

#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<algorithm>
#define N 100005
using namespace std;
int e=1,head[N];
struct edge{
    int u,v,next;
}ed[2*N];
void add(int u,int v){
    ed[e].u=u;ed[e].v=v;
    ed[e].next=head[u];
    head[u]=e++;
}
 
int root[2*N],sum[80*N],lon[80*N],ron[80*N],sz;
int dep[2*N],fa[2*N][18],n,m,v[2*N],num[2*N],num_cnt;
 
void print(int rt,int l,int r){
    if(!rt) return;
    printf("%d  %d  %d  %d  %d  %d\n",rt,l,r,lon[rt],ron[rt],sum[rt]);
    int mid=(l+r)>>1;
    print(lon[rt],l,mid);
    print(ron[rt],mid+1,r);
}
void update(int p,int &rt,int l,int r,int x){
    rt=++sz;
    sum[rt]=sum[p]+1;
    if(l==r) return;
    lon[rt]=lon[p]; ron[rt]=ron[p];
    int mid=(l+r)>>1;
    if(x<=mid) update(lon[p],lon[rt],l,mid,x);
    else update(ron[p],ron[rt],mid+1,r,x);
}
 
void dfs(int x){
    for(int i=1;i<=17;i++)
        if((1<<i)<=dep[x])
            fa[x][i]=fa[fa[x][i-1]][i-1];
        else break;
    update(root[fa[x][0]],root[x],1,num_cnt,v[x]);
    for(int i=head[x];i;i=ed[i].next){
        int v=ed[i].v;
        if(v==fa[x][0]) continue;
        dep[v]=dep[x]+1; fa[v][0]=x;
        dfs(v);
    }
}
int lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
    int t=dep[x]-dep[y];
    for(int i=17;~i;i--)
        if(t&(1<<i))
            x=fa[x][i];
    if(x==y)return x;
    for(int i=17;~i;i--)
        if(fa[x][i]!=fa[y][i])
            x=fa[x][i],y=fa[y][i];
    return fa[x][0];        
}
 
int query(int x,int y,int k){
    int ca=lca(x,y);
    int a=root[x],b=root[y],c=root[ca],d=root[fa[ca][0]];
    int l=1,r=num_cnt;
    while(l<r){
        int mid=(l+r)/2;
        int tmp=sum[lon[a]]+sum[lon[b]]-sum[lon[c]]-sum[lon[d]];
        if(tmp>=k){r=mid;a=lon[a];b=lon[b];c=lon[c];d=lon[d];/*printf("666\n");*/}
        else{k-=tmp;l=mid+1;a=ron[a];b=ron[b];c=ron[c];d=ron[d];}
    }
    //printf("l==%d\n",l);
    return num[l];
}
int main()
{
    int U,V,kth;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d",&v[i]);
        num[i]=v[i];
    }
    sort(num+1,num+n+1);
    num_cnt=unique(num+1,num+n+1)-num-1;
    for(int i=1;i<=n;i++)
        v[i]=lower_bound(num+1,num+num_cnt+1,v[i])-num;
    for(int i=1;i<n;i++){
        scanf("%d%d",&U,&V);
        add(U,V); add(V,U);
    }
    dep[0]=-1;
    dfs(1);
    int ans=0;
    for(int i=1;i<=m;i++){
        scanf("%d%d%d",&U,&V,&kth);
        U^=ans;
        ans=query(U,V,kth);
        printf("%d",ans);
        if(i<m) printf("\n");
    }
    return 0;
}


posted @ 2017-08-03 10:08  Ren_Ivan  阅读(181)  评论(0编辑  收藏  举报