BZOJ2588 树上静态第k大
题意翻译
给你一棵有n个结点的树,节点编号为1~n。
每个节点都有一个权值。
要求执行以下操作:
U V K:求从节点u到节点v的第k小权值。
输入输出格式
输入格式
第一行有两个整数n和m(n,m≤100000) 第二行有n个整数。 第i个整数表示第i个节点的权值。
接下来的n-1行中,每行包含两个整数u v,表示u和v之间有一条边。
接下来的m行,每行包含三个整数U V K,进行一次操作。
输出格式
对于每个操作,输出结果。
解题思路:和序列上的静态主席树差不多
我们先想序列上的做法。对于一个位置i,先令root[i]=root[i-1],然后再在root[i里面插入a[i]。这样每一个位置实际上维护了[1,n]的信息。
同理,放到树上,对于一个节点i,先令root[i]=root[fa[i]],然后再在root[i]里面插入a[i]。这样每一个位置实际上维护了这个节点到根的信息。
查询的时候,对于序列上的情况,我们只需要用root[r]-root[l-1],就可以得到需要的信息了。
放到树上,对于一个询问(u,v),我们需要用root[u]+root[v]-root[lca]-root[fa[lca]],得到需要的信息。
代码:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<vector> using namespace std; const int maxn=100005; int n,m,sz,a[maxn],head[maxn],fa[maxn][55],dep[maxn],cnt,tot,root[maxn*40]; struct node{ int l,r,sum; }T[maxn*40]; vector<int> v; struct Edge{ int u,v,next; }edge[maxn*2]; void add(int u,int v){ edge[tot].v=v; edge[tot].next=head[u]; head[u]=tot++; } int getid(int x){ return lower_bound(v.begin(),v.end(),x)-v.begin()+1; } void update(int l,int r,int &x,int y,int pos){ T[++cnt]=T[y],T[cnt].sum++,x=cnt; if(l==r) return; int mid=(l+r)/2; if(pos<=mid) update(l,mid,T[x].l,T[y].l,pos); else update(mid+1,r,T[x].r,T[y].r,pos); } void dfs(int u,int pre){ dep[u]=dep[pre]+1; fa[u][0]=pre; for(int i=1;i<=25;i++) fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v==pre) continue; update(1,sz,root[v],root[u],getid(a[v])); dfs(v,u); } } int lca(int x,int y){ if(dep[x]<dep[y]) swap(x,y); for(int i=25;i>=0;i--){ if(dep[x]-(1<<i)>=dep[y]) x=fa[x][i]; } if(x==y) return x; for(int i=25;i>=0;i--){ if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; } return fa[x][0]; } int query(int l,int r,int x,int y,int lc,int flc,int k){ if(l==r) return l; int mid=(l+r)/2,sum=0; sum=T[T[x].l].sum+T[T[y].l].sum-T[T[lc].l].sum-T[T[flc].l].sum; if(k<=sum) return query(l,mid,T[x].l,T[y].l,T[lc].l,T[flc].l,k); else return query(mid+1,r,T[x].r,T[y].r,T[lc].r,T[flc].r,k-sum); } int main(){ scanf("%d%d",&n,&m); memset(head,-1,sizeof(head)); for(int i=1;i<=n;i++)scanf("%d",&a[i]),v.push_back(a[i]); sort(v.begin(),v.end()),v.erase(unique(v.begin(),v.end()),v.end()); sz=v.size(); for(int i=1;i<n;i++){ int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } update(1,sz,root[1],root[0],getid(a[1])); dfs(1,0); for(int i=1;i<=m;i++){ int x,y,k; scanf("%d%d%d",&x,&y,&k); int lc=lca(x,y); printf("%d\n",v[query(1,sz,root[x],root[y],root[lc],root[fa[lc][0]],k)-1]); } return 0; }