Count on a tree(bzoj 2588)
Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
8
9
105
7
/* 树上第K大 对于每一个节点,维护一棵权值线段树,记录它到根节点的状态,询问的时候,类似于区间第K大,这条链上的 总数就是sum[a]+sum[b]-sum[lca]-sum[fa[lca]] */ #include<cstdio> #include<iostream> #include<algorithm> #define N 200010 using namespace std; int a[N],b[N],val[N],n,m,len; int head[N],fa[N][25],dep[N],root[N],lc[N*20],rc[N*20],sum[N*20],cnt; struct node{int v,pre;}e[N*2]; void add(int i,int u,int v){ e[i].v=v; e[i].pre=head[u]; head[u]=i; } void pushup(int now){ sum[now]=sum[lc[now]]+sum[rc[now]]; } void change(int last,int &now,int x,int l,int r){ now=++cnt; if(l==r){ sum[now]=sum[last]+1; return; } int mid=l+r>>1; if(x<=mid) change(lc[last],lc[now],x,l,mid),rc[now]=rc[last]; else change(rc[last],rc[now],x,mid+1,r),lc[now]=lc[last]; pushup(now); } void dfs(int x,int f,int c){ dep[x]=c;fa[x][0]=f;change(root[f],root[x],a[x],1,len); for(int i=head[x];i;i=e[i].pre){ if(e[i].v==f) continue; dfs(e[i].v,x,c+1); } } void get_fa(){ for(int j=1;j<=20;j++) for(int i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; } int get_same(int u,int t){ for(int i=0;i<=20;i++) if(t&(1<<i)) u=fa[u][i]; return u; } int LCA(int u,int v){ if(dep[u]<dep[v]) swap(u,v); u=get_same(u,dep[u]-dep[v]); if(u==v) return u; for(int i=20;i>=0;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i]; return fa[u][0]; } int query(int a,int b,int A,int B,int k,int l,int r){ int x=sum[lc[a]]+sum[lc[b]]-sum[lc[A]]-sum[lc[B]]; if(l==r) return l; int mid=l+r>>1; if(k<=x) return query(lc[a],lc[b],lc[A],lc[B],k,l,mid); else return query(rc[a],rc[b],rc[A],rc[B],k-x,mid+1,r); } int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=n;i++){ scanf("%d",&a[i]); b[i]=a[i]; } sort(b+1,b+n+1); len=unique(b+1,b+n+1)-b-1; for(int i=1;i<=n;i++){ int t=lower_bound(b+1,b+len+1,a[i])-b; val[t]=a[i];a[i]=t; } for(int i=1;i<n;i++){ int u,v;scanf("%d%d",&u,&v); add(i*2-1,u,v);add(i*2,v,u); } dfs(1,0,0);get_fa(); int ans=0; for(int i=1;i<=m;i++){ int u,v,k;scanf("%d%d%d",&u,&v,&k);u^=ans; int anc=LCA(u,v); ans=val[query(root[u],root[v],root[anc],root[fa[anc][0]],k,1,len)]; printf("%d",ans); if(i!=m) printf("\n"); } return 0; }