BZOJ2588: Spoj 10628. Count on a tree

题目:http://www.lydsy.com/JudgeOnline/problem.php?id=2588

lca+可持久化线段树

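将点权离散化后，以权值为下标在树上建可持久化线段树：结点 u 的版本由其父结点的版本插入 u 的权值得到，因此版本 u 记录了从根到 u 的路径上各权值的出现次数。查询 u、v 路径上的第 k 小时，设 t=lca(u,v)，在 sum[u]+sum[v]-sum[t]-sum[fa(t)] 对应的“差分”线段树上二分即可。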

#include<algorithm>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<utility>
#include<vector>
// Inclusive ascending loop: declares int i and runs it over l, l+1, ..., r.
#define rep(i,l,r) for (int i=l;i<=r;i++)
// Inclusive descending loop: declares int i and runs it over l, l-1, ..., r.
#define down(i,l,r) for (int i=l;i>=r;i--)
// Fills every byte of x with y via memset; relies on sizeof(x) being the full
// object size, so x must be an array (not a pointer to one).
#define clr(x,y) memset(x,y,sizeof(x))
// Base capacity from which the global arrays of this file are sized.
#define maxn 100500
// Large sentinel value, 10^9 converted to int.
#define inf int(1e9)
using namespace std;
// One directed edge of the adjacency lists: obj is the vertex the edge leads to,
// pre is the index (in e) of the previously inserted edge with the same source
// vertex; pre == 0 terminates the list.
struct data{int obj,pre;
}e[maxn*2];
// head[x]: index in e of the most recently inserted edge leaving x (0 if none).
// pos[u] : preorder index assigned to vertex u by dfs.
// sum/ls/rs: element count, left child and right child of each persistent
//            segment-tree node; node 0 is never allocated and acts as the empty tree.
// dep[u] : depth of u (main sets the depth of vertex 1 to 1).
// root[i]: segment-tree root of the version stored for preorder index i.
int head[maxn],pos[maxn],sum[maxn*22],ls[maxn*22],rs[maxn*22],dep[maxn],root[maxn*20];
// fa[u][i]: 2^i-th ancestor of u (stays 0 when that ancestor does not exist).
// v[i]    : value of vertex i; main replaces it by its rank among distinct values.
// tmp     : sorted copy of the values, used only to build hash.
// hash    : distinct values in ascending order, 1-based (rank -> original value).
//           NOTE(review): with `using namespace std;` an unqualified use of `hash`
//           can clash with std::hash on C++11-and-later toolchains.
// num[i]  : vertex whose preorder index is i (inverse of pos).
int fa[maxn][22],v[maxn],tmp[maxn],hash[maxn],num[maxn];
// n: number of vertices, m: number of queries, ans: answer of the previous query,
// tot: edges stored in e, cnt: number of distinct values, cnt2: segment-tree
// nodes allocated so far, idx: preorder counter, bin[i]: 2^i.
int n,m,ans,tot,cnt,cnt2,idx,bin[22];
// Prepends the directed edge x -> y to the adjacency list of x.
// Edge slots are handed out sequentially starting at 1, so 0 can mark "no edge".
void insert(int x,int y){
    const int id=++tot;
    e[id].pre=head[x];   // chain the new edge in front of x's current list
    e[id].obj=y;
    head[x]=id;
}
// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; if any of the skipped
// characters is '-', the result is negated. Returns 0 when the input is
// exhausted before a digit is found (the original looped forever in that case).
int read(){
    int x=0,f=1;
    // Keep getchar()'s int result: narrowing it to char would make EOF
    // indistinguishable from a real byte on some platforms.
    int ch=getchar();
    while (ch!=EOF && (ch<'0' || ch>'9')) {if (ch=='-') f=-1; ch=getchar();}
    // Subtract '0' before adding so values close to INT_MAX do not overflow
    // in the intermediate sum.
    while (ch>='0' && ch<='9') {x=x*10+(ch-'0'); ch=getchar();}
    return x*f;
}
int find(int x){
    int l=1,r=cnt;
    while (l<r){
        int mid=(l+r)/2;
        if (hash[mid]==x) return mid;
        if (x<hash[mid]) r=mid-1; else l=mid+1;
    }
    return l;
}
void dfs(int u){
    pos[u]=++idx; num[idx]=u;
    rep(i,1,20) if (dep[u]>bin[i]) fa[u][i]=fa[fa[u][i-1]][i-1];
    for (int j=head[u];j;j=e[j].pre){
        int v=e[j].obj;
        if (v!=fa[u][0]){
            fa[v][0]=u;
            dep[v]=dep[u]+1;
            dfs(v);
        }
    }    
}
// Creates a new segment-tree version equal to version x plus one occurrence of
// val, and stores the id of its root in y.
//
// [l, r] is the value range covered by x. Only the nodes on the root-to-leaf
// path of val are newly allocated; every other child link is copied from the
// old version, which therefore stays intact.
void add(int l,int r,int x,int &y,int val){
    int *slot=&y;   // location that must receive the id of the next new node
    for (;;){
        const int node=++cnt2;
        *slot=node;
        sum[node]=sum[x]+1;
        if (l==r) break;
        // Share both subtrees with the old version, then replace the one
        // that contains val on the next iteration.
        ls[node]=ls[x];
        rs[node]=rs[x];
        const int mid=(l+r)/2;
        if (val<=mid){
            slot=&ls[node]; x=ls[x]; r=mid;
        } else {
            slot=&rs[node]; x=rs[x]; l=mid+1;
        }
    }
}
int lca(int x,int y){
    if (dep[x]<dep[y]) swap(x,y);
    int t=dep[x]-dep[y];
    rep(i,0,20) if (t&bin[i]) x=fa[x][i];
    down(i,20,0) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    if (x!=y) return fa[x][0];
    return x;    
}
int ask(int x,int y,int k){
    int t=lca(x,y);
    int a=root[pos[x]],b=root[pos[y]],c=root[pos[t]],d=root[pos[fa[t][0]]];
    int l=1,r=cnt;
    while (l<r){
        int mid=(l+r)/2;
        int tmp=sum[ls[a]]+sum[ls[b]]-sum[ls[c]]-sum[ls[d]];
        if (k<=tmp) {a=ls[a],b=ls[b],c=ls[c],d=ls[d];r=mid;}
        else {k-=tmp; a=rs[a],b=rs[b],c=rs[c],d=rs[d]; l=mid+1;} 
    }
    return hash[l];
}
int main(){
    bin[0]=1; rep(i,1,20) bin[i]=bin[i-1]*2;
    n=read(); m=read();
    rep(i,1,n) v[i]=read(),tmp[i]=v[i];
    sort(tmp+1,tmp+1+n);
    hash[cnt=1]=tmp[1];
    rep(i,2,n) if (tmp[i]!=tmp[i-1]) hash[++cnt]=tmp[i];
    rep(i,1,n) v[i]=find(v[i]);
    rep(i,1,n-1){
        int x=read(),y=read();
        insert(x,y); insert(y,x);
    } 
    dep[1]=1; dfs(1);
    rep(i,1,n){
        int t=num[i];
        add(1,cnt,root[pos[fa[t][0]]],root[i],v[t]);
    }
    rep(i,1,m){
        int x=read(),y=read(),k=read();
        x=x^ans;
        ans=ask(x,y,k);
        if (i!=m) printf("%d\n",ans);
        else printf("%d",ans);
    }
      return 0;
}

 

posted on 2015-12-21 14:11  ctlchild  阅读(221)  评论(0)  编辑  收藏  举报

导航