洛谷P2633 Count on a tree(主席树上树)
题目描述
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
输入输出格式
输入格式:第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
输出格式:M行,表示每个询问的答案。
输入输出样例
输入样例#1:
复制
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
输出样例#1: 复制
2
8
9
105
7
说明
HINT:
N,M<=100000
暴力自重。。。
来源:bzoj2588 Spoj10628.
本题数据为洛谷自造数据,使用CYaRon耗时5分钟完成数据制作。
题解:主席树上树
就是把根节点到该点的链用主席树维护
查询x到y的链上k小值就是将链拆成(1~x)+(1~y)-(1~lca(x,y))-(1~fa(lca(x,y))
显然这种东西是可以主席树维护的,然后就可以A掉了
代码如下:
#include<bits/stdc++.h> #define lson tr[now].l #define rson tr[now].r using namespace std; struct tree { int l,r,sum; }tr[5000050]; int a[100010],b[100010],n,m,cnt,f[100010][18],rt[100010],deep[100010]; vector<int> g[100010]; int init() { map<int,int> m; sort(b+1,b+n+1); int tot=unique(b+1,b+n+1)-b-1; for(int i=1;i<=tot;i++) { m[b[i]]=i; } for(int i=1;i<=n;i++) { a[i]=m[a[i]]; } } int push_up(int now) { tr[now].sum=tr[lson].sum+tr[rson].sum; } int insert(int &now,int fa,int l,int r,int pos) { if(l==r) { now=++cnt; tr[now].sum=tr[fa].sum+1; return 0; } int mid=(l+r)>>1; now=++cnt; tr[now].sum=tr[fa].sum+1; if(pos<=mid) { insert(lson,tr[fa].l,l,mid,pos); tr[now].r=tr[fa].r; } else { insert(rson,tr[fa].r,mid+1,r,pos); tr[now].l=tr[fa].l; } push_up(now); } int query(int t1,int t2,int t3,int t4,int l,int r,int k) { if(l==r) { return l; } int cnt=tr[tr[t1].l].sum+tr[tr[t2].l].sum-tr[tr[t3].l].sum-tr[tr[t4].l].sum; int mid=(l+r)>>1; if(cnt>=k) { query(tr[t1].l,tr[t2].l,tr[t3].l,tr[t4].l,l,mid,k); } else { query(tr[t1].r,tr[t2].r,tr[t3].r,tr[t4].r,mid+1,r,k-cnt); } } int dfs(int now,int ff,int dep) { deep[now]=dep; f[now][0]=ff; for(int i=1;i<=17;i++) f[now][i]=f[f[now][i-1]][i-1]; insert(rt[now],rt[ff],1,100000,a[now]); for(int i=0;i<g[now].size();i++) { if(g[now][i]==ff) continue; dfs(g[now][i],now,dep+1); } } int lca(int x,int y) { if(deep[x]<deep[y]) swap(x,y); for(int i=17;i>=0;i--) { if(deep[f[x][i]]>=deep[y]) x=f[x][i]; } if(x==y) return x; for(int i=17;i>=0;i--) { if(f[x][i]!=f[y][i]) { x=f[x][i]; y=f[y][i]; } } return f[x][0]; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&a[i]); b[i]=a[i]; } init(); int from,to,k; for(int i=1;i<n;i++) { scanf("%d%d",&from,&to); g[from].push_back(to); g[to].push_back(from); } dfs(1,0,1); int ans=0; for(int i=1;i<=m;i++) { scanf("%d%d%d",&from,&to,&k); from^=ans; int l=lca(from,to); int fl=f[l][0]; printf("%d\n",ans=b[query(rt[from],rt[to],rt[l],rt[fl],1,100000,k)]); } }