【BZOJ2588】Count on a tree 题解(主席树+LCA)
前言:其实就是主席树板子啦……只不过变成了树上的查询
--------------------------
题目大意:求树上$u$到$v$路径第$k$大数。
查询静态区间第$k$大肯定是用主席树。我们知道主席树有着优秀的性质:对于前缀和和树上差分等操作都是满足的。感性理解一下:我们在打主席树板子的时候,每次查询都是$query(rt[l-1],rt[r],1,len,k)$,然后$k$与$sum[ls[r]]-sum[ls[l-1]]$比较。所以在进行树上的询问时,我们只要把板子的操作换成$sum[u]+sum[v]-sum[lca]-sum[fa[lca]]$即可。建树的话根据$dfs$序遍历整颗树建立$n$颗权值线段树即可,顺便把树上结点的祖先结点也求了。我们就这样成功AC一道主席树板子题。
PS:一开始RE了,调试代码时发现是把$root$打成$tot$QAQ。
代码:
#include<bits/stdc++.h> #define int long long using namespace std; const int maxn=1000005; int fa[maxn][21],n,m,a[maxn],b[maxn],rt[maxn],tot,len,last,dep[maxn]; int ls[8000005],rs[8000005],sum[8000005]; int head[8000005],cnt; struct node { int next,to; }edge[8000005]; inline int getpos(int x) {return lower_bound(b+1,b+len+1,x)-b;} inline int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } inline void add(int from,int to) { edge[++cnt].next=head[from]; edge[cnt].to=to; head[from]=cnt; } inline int build(int l,int r) { int root=++tot,mid=(l+r)>>1; if (l<r) { ls[root]=build(l,mid); rs[root]=build(mid+1,r); } return root; } inline int update(int k,int l,int r,int root) { int dir=++tot; ls[dir]=ls[root],rs[dir]=rs[root];sum[dir]=sum[root]+1; int mid=(l+r)>>1; if (l<r) { if (k<=mid) ls[dir]=update(k,l,mid,ls[root]); else rs[dir]=update(k,mid+1,r,rs[root]); } return dir; } inline void dfs(int now,int f) { fa[now][0]=f;dep[now]=dep[f]+1; for (int i=1;i<=18;i++) fa[now][i]=fa[fa[now][i-1]][i-1]; rt[now]=update(getpos(a[now]),1,len,rt[f]); for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==f) continue; dfs(to,now); } } inline int LCA(int x,int y) { if (dep[x]<dep[y]) swap(x,y); for (int i=18;i>=0;i--) if (dep[fa[x][i]]>=dep[y]) x=fa[x][i]; if (x==y) return x; for (int i=18;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } inline int query(int u,int v,int f,int ff,int l,int r,int k) { if (l==r) return l; int mid=(l+r)>>1; int x=sum[ls[u]]+sum[ls[v]]-sum[ls[f]]-sum[ls[ff]]; if (k<=x) return query(ls[u],ls[v],ls[f],ls[ff],l,mid,k); else return query(rs[u],rs[v],rs[f],rs[ff],mid+1,r,k-x); } inline int querypath(int u,int v,int k) { int lca=LCA(u,v); return query(rt[u],rt[v],rt[lca],rt[fa[lca][0]],1,len,k); } signed main() { n=read(),m=read(); for (int i=1;i<=n;i++) a[i]=read(),b[i]=a[i]; for (int i=1;i<n;i++) { int x=read(),y=read(); add(x,y);add(y,x); } sort(b+1,b+n+1); len=unique(b+1,b+n+1)-b-1; rt[0]=build(1,len); dfs(1,0); for (int i=1;i<=m;i++) { int u=read(),v=read(),k=read(); u=u^last; printf("%lld\n",last=b[querypath(u,v,k)]); } return 0; }