// Count on a Tree (SPOJ COT): k-th smallest node weight on a tree path (树上第k小)
#include<bits/stdc++.h>
using namespace std;
const int maxn = 100000 + 10;  // capacity for node ids 1..1e5 plus slack

// One node of the persistent segment tree built over weight ranks.
// ls / rs are indices into tree[] (0 = no child); sum counts how many stored
// weights fall into this node's rank range. tot starts at 0 and allocation
// pre-increments it, so index 0 is never written and stays an all-zero node.
struct node
{
    int ls, rs, sum;
    node(int ls = 0, int rs = 0, int sum = 0) : ls(ls), rs(rs), sum(sum) {}
} tree[maxn * 40];  // node pool shared by every version

int a[maxn];          // a[i]: weight of tree node i (1-based)
int v[maxn];          // v[1..cnt]: sorted distinct weights (rank -> weight)
int T[maxn];          // T[u]: root of the version for node u; T[0] is the empty tree
int grand[maxn][20];  // grand[u][i]: 2^i-th ancestor of u, or 0 if none
int depth[maxn];      // depth[u]: depth of u (root = 1); depth[0] stays 0
int n;                // number of tree nodes
int cnt;              // number of distinct weights
int tot;              // pool nodes allocated so far
vector<int> g[maxn];  // adjacency lists of the (undirected) tree
// Translate a raw weight into its 1-based rank among the sorted distinct
// weights v[1..cnt]. For a weight that is not present, this is the position
// where it would be inserted (first element >= x).
inline int getval(int x)
{
    int* first = v + 1;
    int* last = v + cnt + 1;
    int* hit = lower_bound(first, last, x);
    return static_cast<int>(hit - v);
}
// Lowest common ancestor of x and y by binary lifting over grand[][].
// A missing ancestor is stored as 0, and depth[0] == 0, so a jump past the
// root lands on node 0 and is never taken.
int lca(int x,int y)
{
    // Largest jump exponent tried; the same bound is used in both phases.
    const int top = static_cast<int>(floor(log(n) / log(2)) + 1);

    // Make x the deeper of the two nodes.
    if (depth[y] > depth[x])
        swap(x, y);

    // Phase 1: lift x to y's depth. A jump is taken only when the landing
    // node is still at least as deep as y.
    for (int step = top; step >= 0; --step)
    {
        const int up = grand[x][step];
        if (depth[up] >= depth[y])
            x = up;
    }

    if (x == y)
        return x;  // y was an ancestor of x

    // Phase 2: x and y are now at equal depth but are different nodes.
    // If both jumps land on the same node, that node is the LCA or above it,
    // so a jump is only taken when the landing nodes differ. This keeps both
    // nodes strictly below the LCA.
    for (int step = top; step >= 0; --step)
    {
        if (grand[x][step] == grand[y][step])
            continue;
        x = grand[x][step];
        y = grand[y][step];
    }

    // x is now exactly one level below the LCA.
    return grand[x][0];
}
// Build the initial version: a segment tree over ranks [l, r] with every
// count set to 0. Returns its root. Pool nodes are taken in pre-order
// (node, then left subtree, then right subtree).
int build1(int l,int r)
{
    const int root = ++tot;
    tree[root].sum = 0;
    if (l == r)
        return root;  // leaf: no children to create

    const int mid = (l + r) >> 1;
    const int left = build1(l, mid);
    tree[root].ls = left;
    tree[root].rs = build1(mid + 1, r);
    return root;
}
// Create a new version equal to version `last` plus one occurrence of rank
// `val`, over the rank range [l, r]. Returns the new root.
// Only the nodes on the root-to-leaf path for `val` are copied (path copying).
// Every other child index is shared with `last`.
int build2(int l,int r,int last,int val)
{
    // Clone the root of the old version and count the new value in it.
    const int root = ++tot;
    tree[root] = tree[last];
    ++tree[root].sum;

    int cur = root;  // freshly copied node at the current level
    int old = last;  // its counterpart in the old version
    while (l != r)
    {
        const int mid = (l + r) >> 1;
        const bool goLeft = val <= mid;
        const int oldChild = goLeft ? tree[old].ls : tree[old].rs;

        // Copy the child on val's side and count the new value there too.
        const int fresh = ++tot;
        tree[fresh] = tree[oldChild];
        ++tree[fresh].sum;

        if (goLeft)
        {
            tree[cur].ls = fresh;
            r = mid;
        }
        else
        {
            tree[cur].rs = fresh;
            l = mid + 1;
        }
        cur = fresh;
        old = oldChild;
    }
    return root;
}
// Visit every node reachable from `now`, with grand[now][0] taken as now's
// parent. For each node u it sets:
//   depth[u]       = depth[parent] + 1
//   grand[child][0] = u       (for each child of u)
//   T[u]           = version T[parent] plus u's weight rank
// so T[u] counts the weights on the path from the start node down to u.
//
// The traversal uses an explicit stack instead of recursion. On a
// path-shaped tree the recursion depth would reach n (about 1e5), which can
// overflow the call stack. The only ordering requirement is that a node is
// processed after its parent; the order among siblings does not change any
// computed value.
void dfs(int now)
{
    vector<int> pending(1, now);
    while(!pending.empty())
    {
        const int u = pending.back();
        pending.pop_back();

        const int parent = grand[u][0];
        depth[u] = depth[parent] + 1;
        T[u] = build2(1, cnt, T[parent], getval(a[u]));

        for(size_t i = 0; i < g[u].size(); i++)
        {
            const int w = g[u][i];
            if(w == parent) continue;  // skip the edge back to the parent
            grand[w][0] = u;
            pending.push_back(w);
        }
    }
}
// Return the rank (index into v[]) of the k-th smallest weight counted by the
// combination x + y - grandf - grandff, searching the rank range [l, r].
// main passes x = T[u], y = T[v], grandf = T[lca], grandff = T[parent(lca)].
// Under the root-to-node versioning that dfs builds, this combination counts
// exactly the weights on the u..v path.
int query(int x,int y,int grandf,int grandff,int l,int r,int k)
{
    while (l != r)
    {
        const int mid = (l + r) >> 1;
        const int inLeft = tree[tree[x].ls].sum + tree[tree[y].ls].sum
                         - tree[tree[grandf].ls].sum - tree[tree[grandff].ls].sum;
        if (k <= inLeft)
        {
            // The answer lies among the smaller ranks.
            x = tree[x].ls;
            y = tree[y].ls;
            grandf = tree[grandf].ls;
            grandff = tree[grandff].ls;
            r = mid;
        }
        else
        {
            // Skip the inLeft smaller values and continue on the right.
            k -= inLeft;
            x = tree[x].rs;
            y = tree[y].rs;
            grandf = tree[grandf].rs;
            grandff = tree[grandff].rs;
            l = mid + 1;
        }
    }
    return l;
}
int main()
{
int q;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++) scanf("%d",&(a[i])),v[i]=a[i];
for(int i=1;i<n;i++)
{
int u,vv;
scanf("%d%d",&u,&vv);
g[u].push_back(vv);
g[vv].push_back(u);
}
sort(v+1,v+n+1);
cnt=unique(v+1,v+n+1)-v-1;
T[0]=build1(1,cnt);
dfs(1);
for(int i=1;i<=floor(log(n)/log(2))+1;i++)
for(int j=1;j<=n;j++)
grand[j][i]=grand[grand[j][i-1]][i-1];
while(q--)
{
int u,vv,k;
scanf("%d%d%d",&u,&vv,&k);
int tmp=lca(u,vv);
tmp=query(T[u],T[vv],T[tmp],T[grand[tmp][0]],1,cnt,k);
printf("%d\n",v[tmp]);
}
return 0;
}