【bzoj 2588】Spoj 10628. Count on a tree

Description

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

Input

第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。

Output

M行,表示每个询问的答案。最后一个询问不输出换行符

Sample Input

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

Sample Output

2
8
9
105
7

HINT

N,M<=100000
暴力自重。。。
 
 
在树上建主席树,统计答案时为避免LCA被减两次,计算sum[x]+sum[y]-sum[lca(x,y)]-sum[fa[lca(x,y)]]。
 1 #include<cstdio>
 2 #include<algorithm>
 3 #include<cstring>
 4 #define LL long long
 5 using namespace std;
 6 const int N=1e5+5;
 7 int n,m,tot,x,y,rk,cnte,ind,cnt,temp,lastans;
 8 int v[N],tmp[N],first[N],id[N],num[N],root[N];
 9 int deep[N],fa[N][20];
10 struct node{int lc,rc,sum;}tr[N*20];
11 struct edge{int to,next;}e[N*2];
12 int read()
13 {
14     int x=0,f=1;char c=getchar();
15     while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
16     while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
17     return x*f;
18 }
19 void ins(int u,int v){e[++cnte]=(edge){v,first[u]};first[u]=cnte;}
20 void insert(int u,int v){ins(u,v);ins(v,u);}
21 void dfs(int x)
22 {
23     ind++;id[x]=ind;num[ind]=x;
24     for(int i=1;(1<<i)<=deep[x];i++)fa[x][i]=fa[fa[x][i-1]][i-1];
25     for(int i=first[x];i;i=e[i].next)
26     {
27         if(deep[e[i].to])continue;
28         fa[e[i].to][0]=x;deep[e[i].to]=deep[x]+1;dfs(e[i].to);
29     }
30 }
31 int lca(int ri,int rj)
32 {
33     if(deep[ri]<deep[rj])swap(ri,rj);
34     int d=deep[ri]-deep[rj];
35     for(int i=0;(1<<i)<=d;i++)if((1<<i)&d)ri=fa[ri][i];
36     if(ri==rj)return ri;
37     for(int i=17;i>=0;i--)
38         if((1<<i)<=deep[ri]&&fa[ri][i]!=fa[rj][i])
39             ri=fa[ri][i],rj=fa[rj][i];
40     return fa[ri][0];
41 }
42 void update(int &x,int last,int L,int R,int num)
43 {
44     x=++cnt;tr[x].sum=tr[last].sum+1;
45     if(L==R)return;
46     tr[x].lc=tr[last].lc;tr[x].rc=tr[last].rc;
47     int mid=(L+R)>>1;
48     if(num<=mid)update(tr[x].lc,tr[last].lc,L,mid,num);
49     else update(tr[x].rc,tr[last].rc,mid+1,R,num);
50 }
51 int query(int x,int y,int rk)
52 {
53     int a=x,b=y,c=lca(x,y),d=fa[c][0];
54     a=root[id[x]];b=root[id[b]];c=root[id[c]];d=root[id[d]];
55     int L=1,R=tot;
56     while(L<R)
57     {
58         int mid=(L+R)>>1;
59         temp=tr[tr[a].lc].sum+tr[tr[b].lc].sum-tr[tr[c].lc].sum-tr[tr[d].lc].sum;
60         if(temp>=rk)R=mid,a=tr[a].lc,b=tr[b].lc,c=tr[c].lc,d=tr[d].lc;
61         else rk-=temp,L=mid+1,a=tr[a].rc,b=tr[b].rc,c=tr[c].rc,d=tr[d].rc;
62     }
63     return tmp[L];
64 }
65 int main()
66 {
67     n=read();m=read();
68     for(int i=1;i<=n;i++)v[i]=tmp[i]=read();
69     sort(tmp+1,tmp+n+1);tot=unique(tmp+1,tmp+n+1)-tmp-1;
70     for(int i=1;i<=n;i++)v[i]=lower_bound(tmp+1,tmp+tot+1,v[i])-tmp;
71     for(int i=1;i<n;i++)x=read(),y=read(),insert(x,y);
72     deep[1]=1;dfs(1);
73     for(int i=1;i<=n;i++)
74     {
75         temp=num[i];
76         update(root[i],root[id[fa[temp][0]]],1,tot,v[temp]);
77     }
78     for(int i=1;i<=m;i++)
79     {
80         x=read();y=read();rk=read();x^=lastans;
81         lastans=query(x,y,rk);printf("%d",lastans);
82         if(i!=m)printf("\n");
83     }
84     return 0;
85 }
View Code

 

posted @ 2017-12-24 14:04  Zsnuo  阅读(207)  评论(0编辑  收藏  举报