Count on a tree SPOJ - COT (树上第k小)

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
struct node
{
    int ls,rs,sum;
    node(int ls=0,int rs=0,int sum=0)
    {
        this->ls=ls;this->rs=rs;this->sum=sum;
    }
}tree[maxn*40];
int a[maxn],v[maxn],T[maxn],grand[maxn][20],depth[maxn],n,cnt,tot;
vector<int>g[maxn];

inline int getval(int x)
{
    return lower_bound(v+1,v+cnt+1,x)-v;
}
int lca(int x,int y)
{
    if(depth[x]<depth[y]) swap(x,y);//使x是最深的那个点
    for(int i=floor(log(n)/log(2))+1; i>=0; i--)
        if(depth[grand[x][i]]>=depth[y])//如果蹦的话x的深度仍比y的深度深,那就蹦,否则不蹦
            x=grand[x][i];
    if(x!=y)
    {
        //此时x和y的深度相同,但不是同一个节点
        for(int i=floor(log(n)/log(2))+1; i>=0; i--)
        {
            if(grand[x][i]!=grand[y][i])//如果x和y蹦完了到了相同的节点,那么蹦到的节点一定大于或等于lca
            {//所以只有蹦到了不同的节点,才可以蹦(这样才能保证没蹦超出lca)
                x=grand[x][i];
                y=grand[y][i];
            }
        }
        x=grand[x][0];//离真正的lca只差一步(此时不管蹦多少都超,所以前面的循环只能使x是离lca只差一层的节点)
    }
    return x;
}
int build1(int l,int r)
{
    int rt=++tot;
    tree[rt].sum=0;
    if(l!=r)
    {
        int mid=(l+r)>>1;
        tree[rt].ls=build1(l,mid);
        tree[rt].rs=build1(mid+1,r);
    }
    return rt;
}
int build2(int l,int r,int last,int val)
{
    int rt=++tot;
    tree[rt]=tree[last];
    tree[rt].sum++;
    if(l!=r)
    {
        int mid=(l+r)>>1;
        if(val<=mid)
            tree[rt].ls=build2(l,mid,tree[last].ls,val);
        else
            tree[rt].rs=build2(mid+1,r,tree[last].rs,val);
    }
    return rt;
}
void dfs(int now)
{
    depth[now]=depth[grand[now][0]]+1;
    T[now]=build2(1,cnt,T[grand[now][0]],getval(a[now]));
    int sz=g[now].size();
    for(int i=0;i<sz;i++)
    {
        int vv=g[now][i];
        if(vv==grand[now][0]) continue;
        grand[vv][0]=now;
        dfs(vv);
    }
}
int query(int x,int y,int grandf,int grandff,int l,int r,int k)
{
    if(l==r) return l;
    int tmp=tree[tree[x].ls].sum+tree[tree[y].ls].sum-tree[tree[grandf].ls].sum-tree[tree[grandff].ls].sum;
    int mid=(l+r)>>1;
    if(k<=tmp)
        return query(tree[x].ls,tree[y].ls,tree[grandf].ls,tree[grandff].ls,l,mid,k);
    else
        return query(tree[x].rs,tree[y].rs,tree[grandf].rs,tree[grandff].rs,mid+1,r,k-tmp);
}

int main()
{
    int q;
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++) scanf("%d",&(a[i])),v[i]=a[i];
    for(int i=1;i<n;i++)
    {
        int u,vv;
        scanf("%d%d",&u,&vv);
        g[u].push_back(vv);
        g[vv].push_back(u);
    }
    sort(v+1,v+n+1);
    cnt=unique(v+1,v+n+1)-v-1;
    T[0]=build1(1,cnt);
    dfs(1);
    for(int i=1;i<=floor(log(n)/log(2))+1;i++)
        for(int j=1;j<=n;j++)
            grand[j][i]=grand[grand[j][i-1]][i-1];
    while(q--)
    {
        int u,vv,k;
        scanf("%d%d%d",&u,&vv,&k);
        int tmp=lca(u,vv);
        tmp=query(T[u],T[vv],T[tmp],T[grand[tmp][0]],1,cnt,k);
        printf("%d\n",v[tmp]);
    }
    return 0;
}

 

posted @ 2019-05-07 15:51  eason99  阅读(81)  评论(0编辑  收藏  举报