spoj COT - Count on a tree(主席树 +lca,树上第K大)

您将获得一个包含N个节点的树树节点的编号从1Ñ每个节点都有一个整数权重。

我们会要求您执行以下操作:

  • uvk:询问从节点u到节点v的路径上的第k个最小权重

输入

在第一行中有两个整数Ñ中号N,M <= 100000)

在第二行中有N个整数。第i个整数表示第i个节点的权重。

在接下来的N-1行中,每行包含两个整数u v,它描述了一个边(uv)。

在接下来的M行中,每行包含三个整数u v k,这意味着要求从节点u到节点v的路径上的第k个最小权重的操作

 

解题思路:
首先对于求第K小的问题 我们可以用主席树搞 ,没有问题,
但是对于一个树形结构,我们需要将其转化为线性,然后需要树剖才能做.

然后考虑链上的第k值怎么维护 ,
发现如果树剖计算的话 维护不了啊
因为(u,v)的路 可能在很多个链上,那么不能对每个求第K值,这样明显是错误的啊,

然后我们知道主席树其实就是维护了一个前缀和

那么我们可以对每一个节点到根节点建立前缀和,就能找任意一个节点到根节点的第K值,
那么根据主席树的性质,我们就能够计算(u,v)的路上的第K值了
只要在查询的时候稍改变一下就行了

cnt = sum[ls[u]]+sum[ls[v]]-sum[ls[lca(u,v)]]-sum[ls[fa[lca(u,v)]]];

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<vector>
#include<map>
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5+100;
typedef long long LL;
int rt[N*20], ls[N*20], rs[N*20], sum[N*20];
int fa[2*N][30], dep[2*N], vis[N];
int a[N], b[N], tot, cnt, head[N], len;
struct node
{
    int to, next;
} p[2*N];
void init()
{
    memset(head,-1,sizeof(head));
    memset(vis,0,sizeof(vis));
    cnt=0;
    return ;
}
void add(int u,int v)
{
    p[cnt].to=v,p[cnt].next=head[u];head[u]=cnt++;
    p[cnt].to=u,p[cnt].next=head[v];head[v]=cnt++;
    return ;
}
void build(int &o,int l,int r)
{
    o= ++tot,sum[o]=0;
    if(l==r) return ;
    int mid=(l+r)/2;
    build(ls[o],l,mid);
    build(rs[o],mid+1,r);
    return ;
}
void update(int &o,int l,int r,int last,int p)
{
    o= ++tot;
    ls[o]=ls[last],rs[o]=rs[last];
    sum[o]=sum[last]+1;
    if(l==r) return ;
    int mid=(l+r)/2;
    if(p<=mid) update(ls[o],l,mid,ls[last],p);
    else update(rs[o],mid+1,r,rs[last],p);
    return ;
}
int query(int ss,int tt,int s1,int t1,int l,int r,int cnt)
{
    if(l==r) return l;
    int tmp=sum[ls[tt]]+sum[ls[ss]]-sum[ls[s1]]-sum[ls[t1]];
    int mid=(l+r)/2;
    if(tmp>=cnt) return query(ls[ss],ls[tt],ls[s1],ls[t1],l,mid,cnt);
    else return query(rs[ss],rs[tt],rs[s1],rs[t1],mid+1,r,cnt-tmp);
}
void dfs(int u,int d,int f,int root)
{
    vis[u]=1,dep[u]=d,fa[u][0]=f;
    update(rt[u],1,len,root,a[u]);
    root=rt[u];
    for(int i=head[u];i!=-1;i=p[i].next)
    {
        int v=p[i].to;
        if(vis[v]) continue;
        dfs(v,d+1,u,root);
    }
    return ;
}
void lca(int n)
{
    int k=(int)(log(1.0*n)/log(2.0));
    for(int i=1;i<=k;i++)
    {
        for(int j=1;j<=n;j++)
        {
            fa[j][i]=fa[fa[j][i-1]][i-1];
        }
    }
    return ;
}
int get(int x,int y,int n)
{
    if(dep[x]<dep[y]) swap(x,y);
    int k=(int)(log(1.0*n)/log(2.0));
    int d=dep[x]-dep[y];
    for(int i=0;i<=k;i++)
        if((d&(1<<i))) x=fa[x][i];
    if(x==y) return x;
    for(int i=k;i>=0;i--)
    {
        if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    }
    return fa[x][0];
}
 
int main()
{
    int t, n, q;
    scanf("%d %d", &n, &q);
    for(int i=1; i<=n; i++) scanf("%d", &a[i]), b[i]=a[i];
    sort(b+1,b+n+1);
    len=unique(b+1,b+n+1)-(b+1);
    tot=0;
    build(rt[0],1,len);
    for(int i=1; i<=n; i++)  a[i]=lower_bound(b+1,b+len+1,a[i])-(b);
    init();
    for(int i=0;i<n-1;i++)
    {
        int x, y;
        scanf("%d %d", &x, &y);
        add(x,y);
    }
    dfs(1,1,0,rt[0]);
    lca(n);
    while(q--)
    {
        int l, r, x;
        scanf("%d %d %d", &l, &r, &x);
        int pos=get(l,r,n);
        printf("%d\n",b[query(rt[l],rt[r],rt[pos],rt[fa[pos][0]],1,len,x)]);
    }
    return 0;
}
View Code

 

 
posted @ 2018-10-07 11:11  shuai_hui  阅读(182)  评论(0编辑  收藏  举报