SPOJ COT Count on a tree(可持久化线段树+倍增lca)

题意:给你颗树,灭个节点都有一个权值,询问你a到b上的路径的地k小

思路:这个题其实就是树上的第k小,主席树的本质还是类似于前缀和一样的结构,所以是完全相同的,所以我们在树上也可以用同样的方法,我们对于每一个节点进行建树,然后和普通的树上相同,ab之间的距离是等于

root[a]+root[b]-root[lca[a,b]]-root[fa[lca[a,b]]]

代码:

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+7;
const int POW=18;
int num[maxn],sa[maxn];
int ls[maxn*40],rs[maxn*40];
int sum[maxn*40];
int root[maxn];
vector<int>mp[maxn];
int dis[maxn];
int p[maxn][POW];
int cnt;
int f[maxn];

void build(int l,int r,int &rt)
{
    rt=++cnt;
    sum[rt]=0;
    if(l>=r)return ;
    int mid=(l+r)>>1;
    build(l,mid,ls[rt]);
    build(mid+1,r,rs[rt]);
}
void update(int last,int p,int l,int r,int &rt)
{
    rt=++cnt;
    ls[rt]=ls[last];
    rs[rt]=rs[last];
    sum[rt]=sum[last]+1;
    if(l>=r)return ;
    int mid=(l+r)>>1;
    if(p<=mid)update(ls[last],p,l,mid,ls[rt]);
    else update(rs[last],p,mid+1,r,rs[rt]);
}
int query(int lrt,int rrt,int lcart,int lcafrt,int l,int r,int k)
{
    if(l>=r)return l;
    int mid=(l+r)>>1;
    int ans=sum[ls[rrt]]+sum[ls[lrt]]-sum[ls[lcart]]-sum[ls[lcafrt]];
    if(k<=ans)
        return query(ls[lrt],ls[rrt],ls[lcart],ls[lcafrt],l,mid,k);
    else
        return query(rs[lrt],rs[rrt],rs[lcart],rs[lcafrt],mid+1,r,k-ans);
}
void dfs(int u,int fa,int tot)
{
    f[u]=fa;
    dis[u]=dis[fa]+1;
    p[u][0]=fa;
    for(int i=1;i<POW;i++)
        p[u][i]=p[p[u][i-1]][i-1];
    update(root[fa],num[u],1,tot,root[u]);
    for(int i=0;i<mp[u].size();i++){
        int v=mp[u][i];
        if(v==fa)continue;
        dfs(v,u,tot);
    }
}
int lca(int a,int b)
{
    if(dis[a]>dis[b])swap(a,b);
    if(dis[a]<dis[b]){
        int del=dis[b]-dis[a];
        for(int i=0;i<POW;i++)
            if(del&(1<<i))b=p[b][i];
    }
    if(a!=b){
        for(int i=POW-1;i>=0;i--){
            if(p[a][i]!=p[b][i]){
                a=p[a][i];b=p[b][i];
            }
        }
        a=p[a][0];b=p[b][0];
    }
    return a;
}
int main()
{
    int n,m;
    while(~scanf("%d%d",&n,&m)){
        for(int i=0;i<n;i++)mp[i].clear();
        memset(dis,0,sizeof(dis));
        memset(p,0,sizeof(p));
        memset(f,0,sizeof(f));
        for(int i=1;i<=n;i++){
            scanf("%d",&num[i]);
            sa[i]=num[i];
        }
        cnt=0;
        sort(sa+1,sa+1+n);
        int tot=unique(sa+1,sa+1+n)-sa-1;
        for(int i=1;i<=n;i++){
            num[i]=lower_bound(sa+1,sa+tot+1,num[i])-sa;
        }
        int a,b,c;
        for(int i=1;i<n;i++){
            scanf("%d%d",&a,&b);
            mp[a].push_back(b);
            mp[b].push_back(a);
        }
        build(1,tot,root[0]);
        dfs(1,0,tot);
        for(int i=1;i<=m;i++){
            scanf("%d%d%d",&a,&b,&c);
            int t=lca(a,b);
            int id=query(root[a],root[b],root[t],root[f[t]],1,tot,c);
            printf("%d\n",sa[id]);
        }
    }
    return 0;
}
/*
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
2 5 2
2 5 3
2 5 4
7 8 2
*/

 

posted @ 2018-07-20 11:13  啦啦啦天啦噜  阅读(164)  评论(0编辑  收藏  举报