BZOJ2588 count on a tree
题目链接:戳我
一道主席树的题。(其实我还是不太会做,所以参考了一些其他人写的博客)
(原先主席树就只打过模板qwq,所以特来好好的学习一下,发现其实应用还是很奇妙的)
题目的大概意思是求树上的链中第k大数——主席树啊!
我们想到平常的主席树是怎么搞的?它的原理不就是利用前缀和,然后做差嘛。现在放到树上了之后,就是树上差分。
这里的变形是把原先的继承上个节点的子节点变成继承树上的父亲的子节点。(这样才能做树上差分)查询的时候利用树上点差分的思想,查询u,v,lca(u,v),fa[lca(u,v)]即可。
代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAXN 100010
using namespace std;
int n,m,tt,cnt,pre,sum;
int head[MAXN],son[MAXN],siz[MAXN],top[MAXN];
int dep[MAXN],fa[MAXN],a[MAXN],rt[MAXN],p[MAXN];
struct Edge{int nxt,to;}edge[MAXN<<1];
struct Node{int ls,rs,v;}t[MAXN<<5];
inline void add(int from,int to)
{
edge[++tt].nxt=head[from],edge[tt].to=to,head[from]=tt;
edge[++tt].nxt=head[to],edge[tt].to=from,head[to]=tt;
}
inline void build(int &x,int l,int r)
{
x=++cnt;
int mid=(l+r)>>1;
if(l==r) return;
build(t[x].ls,l,mid);
build(t[x].rs,mid+1,r);
}
inline void update(int &x,int l,int r,int pos,int ff)
{
x=++cnt;
t[x].ls=t[ff].ls,t[x].rs=t[ff].rs,t[x].v=t[ff].v+1;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) update(t[x].ls,l,mid,pos,t[ff].ls);
else update(t[x].rs,mid+1,r,pos,t[ff].rs);
}
inline int query(int cur1,int cur2,int cur3,int cur4,int l,int r,int k)
{
if(l==r) return l;
int tmp=t[t[cur1].ls].v+t[t[cur2].ls].v-t[t[cur3].ls].v-t[t[cur4].ls].v;
int mid=(l+r)>>1;
if(tmp>=k) return query(t[cur1].ls,t[cur2].ls,t[cur3].ls,t[cur4].ls,l,mid,k);
else return query(t[cur1].rs,t[cur2].rs,t[cur3].rs,t[cur4].rs,mid+1,r,k-tmp);
}
inline void dfs1(int now,int pre)
{
int maxx=-1;
dep[now]=dep[pre]+1;
fa[now]=pre;
siz[now]=1;
update(rt[now],1,sum,a[now],rt[pre]);
for(int i=head[now];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(v==pre) continue;
dfs1(v,now);
siz[now]+=siz[v];
if(siz[v]>maxx) maxx=siz[v],son[now]=v;
}
}
inline void dfs2(int now,int topf)
{
top[now]=topf;
if(son[now]) dfs2(son[now],topf);
for(int i=head[now];i;i=edge[i].nxt)
{
int v=edge[i].to;
if(v==fa[now]||v==son[now]) continue;
dfs2(v,v);
}
}
inline int lca(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]<dep[y]) return x;
else return y;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("ce.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&p[i]);
memcpy(a,p,sizeof(p));
sort(&p[1],&p[1+n]);
sum=unique(&p[1],&p[1+n])-p-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(&p[1],&p[1+sum],a[i])-p;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
}
dfs1(1,0);
dfs2(1,1);
build(rt[0],1,sum);
for(int i=1;i<=m;i++)
{
int u,v,k;
scanf("%d%d%d",&u,&v,&k);
u^=pre;
int cur=lca(u,v);
//printf("u=%d v=%d lca=%d\n",u,v,cur);
printf("%d\n",pre=p[query(rt[u],rt[v],rt[cur],rt[fa[cur]],1,sum,k)]);
}
return 0;
}