[bzoj2588] Count on a tree
一道树上差分的题。
每个点都建一个权值线段树,维护的是从这个点到根的链的信息。
这样就可以用主席树了,每个点的版本由其父节点的版本加上该点的权值得来。
剩下的就没什么了,主席树上二分查找第k小什么的。
我用的是树链剖分LCA。
本来早就写了这道题,一直卡着。
今天发现是第36行的if(val<=mid)写成了if(p<=mid)......
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 using namespace std; 5 6 int n,m; 7 int raw[100005]; 8 9 struct data 10 { 11 int rv,nv,id; 12 }dt[100005]; 13 14 int cmp(data q,data w){return q.rv<w.rv;} 15 int cmpb(data q,data w){return q.id<w.id;} 16 17 int hd[100005],nx[200005],to[200005],ec; 18 19 void edge(int af,int at) 20 { 21 to[++ec]=at; 22 nx[ec]=hd[af]; 23 hd[af]=ec; 24 } 25 26 int rt[100005]; 27 int tot[2400005],ls[2400005],rs[2400005],pc; 28 29 void edit(int &p,int last,int l,int r,int val) 30 { 31 p=++pc; 32 ls[p]=ls[last],rs[p]=rs[last]; 33 tot[p]=tot[last]+1; 34 if(l==r)return; 35 int mid=(l+r)>>1; 36 if(val<=mid)edit(ls[p],ls[last],l,mid,val); 37 else edit(rs[p],rs[last],mid+1,r,val); 38 } 39 40 int query(int x,int y,int ll,int fl,int l,int r,int k) 41 { 42 if(l==r)return l; 43 int lsum=tot[ls[x]]+tot[ls[y]]-tot[ls[ll]]-tot[ls[fl]]; 44 int mid=(l+r)>>1; 45 if(k<=lsum)return query(ls[x],ls[y],ls[ll],ls[fl],l,mid,k); 46 else return query(rs[x],rs[y],rs[ll],rs[fl],mid+1,r,k-lsum); 47 } 48 49 int dep[100005],sz[100005],f[100005]; 50 int son[100005],tp[100005]; 51 52 void dfs(int p,int fa) 53 { 54 sz[p]=1,dep[p]=dep[fa]+1,f[p]=fa; 55 edit(rt[p],rt[fa],1,n,dt[p].nv); 56 for(int i=hd[p];i;i=nx[i]) 57 { 58 if(to[i]==fa)continue; 59 dfs(to[i],p),sz[p]+=sz[to[i]]; 60 if(sz[to[i]]>sz[son[p]])son[p]=to[i]; 61 } 62 } 63 64 void findtp(int p) 65 { 66 if(p==son[f[p]])tp[p]=tp[f[p]]; 67 else tp[p]=p; 68 for(int i=hd[p];i;i=nx[i]) 69 if(to[i]!=f[p])findtp(to[i]); 70 } 71 72 int lca(int x,int y) 73 { 74 while(tp[x]!=tp[y])dep[tp[x]]>dep[tp[y]]?x=f[tp[x]]:y=f[tp[y]]; 75 return dep[x]>dep[y]?y:x; 76 } 77 78 int main() 79 { 80 scanf("%d%d",&n,&m); 81 for(int i=1;i<=n;i++)scanf("%d",&dt[i].rv),dt[i].id=i; 82 sort(dt+1,dt+n+1,cmp); 83 for(int i=1;i<=n;i++) 84 { 85 if(dt[i].rv==dt[i-1].rv)dt[i].nv=dt[i-1].nv; 86 else dt[i].nv=dt[i-1].nv+1,raw[dt[i].nv]=dt[i].rv; 87 } 88 sort(dt+1,dt+n+1,cmpb); 89 for(int i=1;i<n;i++) 90 { 91 int ff,tt; 92 scanf("%d%d",&ff,&tt); 93 edge(ff,tt); 94 edge(tt,ff); 95 } 96 dfs(1,0); 97 findtp(1); 98 int ans=0; 99 for(int i=1;i<=m;i++) 100 { 101 int x,y,k; 102 scanf("%d%d%d",&x,&y,&k); 103 x^=ans; 104 int l=lca(x,y); 105 int na=query(rt[x],rt[y],rt[l],rt[f[l]],1,n,k); 106 ans=raw[na]; 107 printf("%d\n",ans); 108 } 109 return 0; 110 }