【BZOJ2588】Spoj 10628. Count on a tree 主席树+LCA

【BZOJ2588】Spoj 10628. Count on a tree

Description

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

Input

第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。

Output

M行,表示每个询问的答案。

Sample Input

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

Sample Output

2
8
9
105
7

HINT

N,M<=100000
暴力自重。。。

题解:先树剖求出LCA,然后再树上搞一个主席树,第i棵线段树保存的是从根到i的路径上的所有点,然后查询的时候就在用a的线段树+b的线段树-lca(a,b)的线段树-fa(lca(a,b))的线段树就行了。

主席树写的还是不熟练,注意在建树的时候一定要按着DFS序建。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn=100010;
struct NUM
{
    int num,org;
}v[maxn];
struct NODE
{
    int siz,ls,rs;
}s[maxn*40];
int n,m,tot,nm,ans,cnt;
int ref[maxn],root[maxn],q[maxn];
int fa[maxn],size[maxn],top[maxn],son[maxn],deep[maxn],to[maxn<<1],next[maxn<<1],head[maxn];
bool cmp1(NUM a,NUM b)
{
    return a.num<b.num;
}
bool cmp2(NUM a,NUM b)
{
    return a.org<b.org;
}
void add(int a,int b)
{
    to[cnt]=b;
    next[cnt]=head[a];
    head[a]=cnt++;
}
void dfs1(int x)
{
    size[x]=1;
    for(int i=head[x];i!=-1;i=next[i])
    {
        if(to[i]!=fa[x])
        {
            fa[to[i]]=x,deep[to[i]]=deep[x]+1;
            dfs1(to[i]);
            size[x]+=size[to[i]];
            if(size[to[i]]>size[son[x]])    son[x]=to[i];
        }
    }
}
void dfs2(int x,int tp)
{
    top[x]=tp;
    q[++q[0]]=x;
    if(son[x])    dfs2(son[x],tp);
    for(int i=head[x];i!=-1;i=next[i])
        if(to[i]!=fa[x]&&to[i]!=son[x])
            dfs2(to[i],to[i]);
}
void insert(int x,int &y,int l,int r,int p)
{
    y=++tot;
    if(l==r)
    {
        s[y].siz=s[x].siz+1;
        return ;
    }
    int mid=l+r>>1;
    if(p<=mid)    s[y].rs=s[x].rs,insert(s[x].ls,s[y].ls,l,mid,p);
    else    s[y].ls=s[x].ls,insert(s[x].rs,s[y].rs,mid+1,r,p);
    s[y].siz=s[s[y].ls].siz+s[s[y].rs].siz;
}
void query(int a,int b,int c,int d,int l,int r,int p)
{
    if(l==r)
    {
        ans=ref[l];
        return ;
    }
    int mid=l+r>>1;
    if(s[s[a].ls].siz+s[s[b].ls].siz-s[s[c].ls].siz-s[s[d].ls].siz>=p)    query(s[a].ls,s[b].ls,s[c].ls,s[d].ls,l,mid,p);
    else    query(s[a].rs,s[b].rs,s[c].rs,s[d].rs,mid+1,r,p-s[s[a].ls].siz-s[s[b].ls].siz+s[s[c].ls].siz+s[s[d].ls].siz);
}
int getlca(int x,int y)
{
    while(top[x]!=top[y])
    {
        if(deep[top[x]]>deep[top[y]])    x=fa[top[x]];
        else    y=fa[top[y]];
    }
    if(deep[x]<deep[y])    return x;
    return y;
}
int readin()
{
    int ret=0,f=1;    char gc=getchar();
    while(gc<'0'||gc>'9')    {if(gc=='-')f=-f;    gc=getchar();}
    while(gc>='0'&&gc<='9')    ret=ret*10+gc-'0',gc=getchar();
    return ret*f;
}
int main()
{
    n=readin(),m=readin();
    memset(head,-1,sizeof(head));
    int i,a,b,c,d,e;
    for(i=1;i<=n;i++)    v[i].num=readin(),v[i].org=i;
    sort(v+1,v+n+1,cmp1);
    ref[nm]=-1;
    for(i=1;i<=n;i++)
    {
        if(v[i].num>ref[nm])    ref[++nm]=v[i].num;
        v[i].num=nm;
    }
    sort(v+1,v+n+1,cmp2);
    for(i=1;i<n;i++)
    {
        a=readin(),b=readin();
        add(a,b),add(b,a);
    }
    deep[1]=1,dfs1(1),dfs2(1,1);
    for(i=1;i<=n;i++)    insert(root[fa[q[i]]],root[q[i]],1,nm,v[q[i]].num);
    for(i=1;i<=m;i++)
    {
        scanf("%d%d%d",&a,&b,&e);
        a^=ans;
        int c=getlca(a,b),d=fa[c];
        query(root[a],root[b],root[c],root[d],1,nm,e);
        printf("%d",ans);
        if(i!=m)    printf("\n");
    }
    return 0;
}
posted @ 2017-01-17 11:30  CQzhangyu  阅读(326)  评论(2编辑  收藏  举报