BZOJ2588: Spoj 10628. Count on a tree
题目:http://www.lydsy.com/JudgeOnline/problem.php?id=2588
lca+可持久化线段树
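在树上建一棵可持久化线段树就可以了：每个结点的版本在其父亲版本的基础上插入自身权值（权值先离散化），于是版本(u)记录了根到 u 路径上的权值分布；查询 x 到 y 路径第 k 小时，用 版本(x)+版本(y)-版本(lca)-版本(fa[lca]) 在线段树上二分即可。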
// BZOJ 2588 / SPOJ 10628 "Count on a tree": answer online queries "k-th smallest
// node weight on the tree path x..y".
// Approach: weights are discretized; every node u gets a persistent segment tree
// version built on top of its parent's version, so version(u) counts the weights
// on the root->u path. The path x..y is then
//   version(x) + version(y) - version(lca) - version(parent(lca)).
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <utility>

const int maxn = 100500;   // array bound for vertices / distinct weights
const int MAXLOG = 20;     // binary-lifting levels used: 0..MAXLOG

// Adjacency list; each undirected edge is stored twice, hence maxn * 2.
struct Edge { int obj, pre; } e[maxn * 2];
int head[maxn];
int pos[maxn];             // pos[u]  : DFS pre-order index of vertex u
int num[maxn];             // num[i]  : vertex with DFS index i
int dep[maxn];             // depth, root has depth 1
int fa[maxn][MAXLOG + 1];  // fa[u][i]: 2^i-th ancestor of u (0 if none)
int stk[maxn];             // explicit DFS stack
int v[maxn];               // vertex weights (discretized after reading)
int vals[maxn];            // sorted distinct weights, 1-based (was `hash`, which clashes with std::hash)
int root[maxn];            // root[i]: segment tree version of the vertex with DFS index i
int sum[maxn * 22], ls[maxn * 22], rs[maxn * 22];  // persistent segment tree node pool
int n, m, ans, tot, cnt, cnt2, idx;

// Append the directed edge x -> y to x's adjacency list.
void insert(int x, int y) {
    e[++tot].obj = y;
    e[tot].pre = head[x];
    head[x] = tot;
}

// Read a (possibly negative) decimal integer from stdin, skipping any leading
// non-digit characters. Returns 0 if EOF is hit before any digit.
int read() {
    int x = 0, f = 1;
    int ch = getchar();                 // int, not char: isdigit() needs EOF/unsigned char values
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Map a raw weight to its 1-based rank in vals[1..cnt].
// The weight must be present (every input weight was inserted into vals).
int find(int x) {
    return int(std::lower_bound(vals + 1, vals + cnt + 1, x) - vals);
}

// Pre-order DFS from u (fa[u][0] and dep[u] must already be set).
// Fills pos[], num[], dep[] and the binary-lifting table fa[][].
// Iterative so that path-like trees with ~1e5 vertices cannot overflow the call stack;
// a vertex is always numbered after its parent, which the version-building loop relies on.
void dfs(int u) {
    int top = 0;
    stk[++top] = u;
    while (top) {
        int x = stk[top--];
        pos[x] = ++idx;
        num[idx] = x;
        for (int i = 1; i <= MAXLOG; i++)
            if (dep[x] > (1 << i)) fa[x][i] = fa[fa[x][i - 1]][i - 1];
        for (int j = head[x]; j; j = e[j].pre) {
            int y = e[j].obj;
            if (y != fa[x][0]) {
                fa[y][0] = x;
                dep[y] = dep[x] + 1;
                stk[++top] = y;
            }
        }
    }
}

// Persistent insert: build version y from version x by adding one occurrence of
// `val` over the value range [l, r]. Only the nodes on the root-to-leaf path are copied.
void add(int l, int r, int x, int &y, int val) {
    y = ++cnt2;
    sum[y] = sum[x] + 1;
    if (l == r) return;
    ls[y] = ls[x];
    rs[y] = rs[x];
    int mid = (l + r) / 2;
    if (val <= mid) add(l, mid, ls[x], ls[y], val);
    else add(mid + 1, r, rs[x], rs[y], val);
}

// Lowest common ancestor of x and y via binary lifting.
int lca(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    int t = dep[x] - dep[y];
    for (int i = 0; i <= MAXLOG; i++)          // lift x to y's depth
        if ((t >> i) & 1) x = fa[x][i];
    if (x == y) return x;
    for (int i = MAXLOG; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

// k-th smallest (original, non-discretized) weight on the path x..y.
// Walks the four versions simultaneously; version 0 (index 0) is the empty tree,
// which covers fa[lca][0] == 0 when the lca is the root.
int ask(int x, int y, int k) {
    int t = lca(x, y);
    int a = root[pos[x]], b = root[pos[y]], c = root[pos[t]], d = root[pos[fa[t][0]]];
    int l = 1, r = cnt;
    while (l < r) {
        int mid = (l + r) / 2;
        int left = sum[ls[a]] + sum[ls[b]] - sum[ls[c]] - sum[ls[d]];  // path count in [l, mid]
        if (k <= left) {
            a = ls[a]; b = ls[b]; c = ls[c]; d = ls[d];
            r = mid;
        } else {
            k -= left;
            a = rs[a]; b = rs[b]; c = rs[c]; d = rs[d];
            l = mid + 1;
        }
    }
    return vals[l];
}

int main() {
    n = read();
    m = read();
    for (int i = 1; i <= n; i++) {
        v[i] = read();
        vals[i] = v[i];
    }
    // Discretize: vals[1..cnt] becomes the sorted distinct weights.
    std::sort(vals + 1, vals + n + 1);
    cnt = int(std::unique(vals + 1, vals + n + 1) - (vals + 1));
    for (int i = 1; i <= n; i++) v[i] = find(v[i]);

    for (int i = 1; i < n; i++) {
        int x = read(), y = read();
        insert(x, y);
        insert(y, x);
    }
    dep[1] = 1;
    dfs(1);

    // Build versions in DFS order so the parent's version always exists already.
    for (int i = 1; i <= n; i++) {
        int t = num[i];
        add(1, cnt, root[pos[fa[t][0]]], root[i], v[t]);
    }

    for (int i = 1; i <= m; i++) {
        int x = read(), y = read(), k = read();
        x ^= ans;   // online queries: x is XOR-ed with the previous answer (0 initially)
        ans = ask(x, y, k);
        if (i != m) printf("%d\n", ans);
        else printf("%d", ans);
    }
    return 0;
}