[BZOJ 2588]Count on a tree
LCA+主席树(可持久化线段树)
取一个点为根,每棵线段树记录树上节点到根的链上的权在数轴上的分布(当然要离散化),
则对于两个点u,v的路径上的数在数轴上的分布可以表示为tree[u]+tree[v]-tree[lca(u,v)]-tree[fa(u,v)](可以随便画图YY一下),
然后就可以在得到的树上查询了。
代码:
#include<bits/stdc++.h> using namespace std; #define N ((1<<17)-1) int n,m,rv[N],v[N],dep[N],jp[N][20],cnt,p[N]; bool vis[N]; vector<vector<int> >G; struct SN{ SN *son[2]; int val; }sn[N*20],*root[N]; void build(SN&x,int l,int r) { if (l==r) return; int m=(l+r)>>1; x.son[0]=&sn[++cnt]; x.son[1]=&sn[++cnt]; build(*x.son[0],l,m); build(*x.son[1],m+1,r); } inline bool cmp(int a,int b){return rv[a]<rv[b];} void putin() { int i,x,y; scanf("%d%d",&n,&m); G.resize(n+1); for (i=1;i<=n;i++) scanf("%d",&rv[i]),p[i-1]=i; sort(p,p+n,cmp); for (i=0;i<n;i++) v[p[i]]=i; for (i=0;i<n-1;i++) { scanf("%d%d",&x,&y); G[x].push_back(y); G[y].push_back(x); } build(sn[0],0,n-1); root[0]=&sn[0]; } void rebuild(SN&x,int l,int r,int k) { x.val++; if (l==r) return; int m=(l+r)>>1,s; if (k<=m) s=0,r=m; else s=1,l=m+1; sn[++cnt]=*x.son[s]; x.son[s]=&sn[cnt]; rebuild(*x.son[s],l,r,k); } void dfs(int x,int f,int d) { int i; vis[x]=1; dep[x]=d; jp[x][0]=f; for (i=0;jp[x][i];i++) jp[x][i+1]=jp[jp[x][i]][i]; root[x]=&sn[++cnt]; *root[x]=*root[f]; rebuild(*root[x],0,n-1,v[x]); for (i=0;i<G[x].size();i++) if (!vis[G[x][i]]) dfs(G[x][i],x,d+1); } int search(SN&a,SN&b,SN&c,SN&d,int l,int r,int k) { if (l==r) return l; int val=(*a.son[0]).val+(*b.son[0]).val-(*c.son[0]).val-(*d.son[0]).val; int m=(l+r)>>1,s; if (val>=k) s=0,r=m; else s=1,k-=val,l=m+1; return search(*a.son[s],*b.son[s],*c.son[s],*d.son[s],l,r,k); } int LCA(int u,int v) { int i; if (dep[u]!=dep[v]) { if (dep[u]<dep[v]) swap(u,v); for (i=17;i>=0;i--) if (dep[u]-(1<<i)>=dep[v])u=jp[u][i]; } if (u==v) return u; for (i=17;i>=0;i--) if (jp[u][i]!=jp[v][i]) { u=jp[u][i]; v=jp[v][i]; } return jp[u][0]; } void answer() { int i,ans=0,u,v,k,lca; for (i=0;i<m;i++) { scanf("%d%d%d",&u,&v,&k); u^=ans; lca=LCA(u,v); ans=rv[p[search(*root[u],*root[v],*root[lca],*root[jp[lca][0]],0,n-1,k)]]; printf("%d",ans); if (i<m-1) printf("\n"); } } int main() { putin(); dfs(1,0,1); answer(); }