BZOJ - 2588 Spoj 10628. Count on a tree (可持久化线段树+LCA/树链剖分)
第一种方法,dfs序上建可持久化线段树,然后询问的时候把两点之间的所有树链扒出来做差。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e5+10,inf=0x3f3f3f3f; 5 int hd[N],ne,n,n2,m,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N],ql[100],qr[100],nl,nr; 6 struct E {int v,nxt;} e[N<<1]; 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 8 void dfs1(int u,int f,int d) { 9 fa[u]=f,son[u]=0,siz[u]=1,dep[u]=d; 10 for(int i=hd[u]; ~i; i=e[i].nxt) { 11 int v=e[i].v; 12 if(v==fa[u])continue; 13 dfs1(v,u,d+1),siz[u]+=siz[v]; 14 if(siz[v]>siz[son[u]])son[u]=v; 15 } 16 } 17 void dfs2(int u,int tp) { 18 top[u]=tp,dfn[u]=++tot,rnk[tot]=u; 19 if(son[u])dfs2(son[u],tp); 20 for(int i=hd[u]; ~i; i=e[i].nxt) { 21 int v=e[i].v; 22 if(v==fa[u]||v==son[u])continue; 23 dfs2(v,v); 24 } 25 } 26 #define mid ((l+r)>>1) 27 void upd(int& u,int v,int x,int l=1,int r=n2) { 28 if(!u)u=++tot2; 29 val[u]=val[v]+1; 30 if(l==r)return; 31 if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v]; 32 else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v]; 33 } 34 int ask(int u,int v,int k) { 35 for(nl=nr=0; top[u]!=top[v]; u=fa[top[u]]) { 36 if(dep[top[u]]<dep[top[v]])swap(u,v); 37 ql[nl++]=rt[dfn[top[u]]-1],qr[nr++]=rt[dfn[u]]; 38 } 39 if(dep[u]<dep[v])swap(u,v); 40 ql[nl++]=rt[dfn[v]-1],qr[nr++]=rt[dfn[u]]; 41 int l=1,r=n2; 42 while(l<r) { 43 int cnt=0; 44 for(int i=0; i<nr; ++i)cnt+=val[ls[qr[i]]]; 45 for(int i=0; i<nl; ++i)cnt-=val[ls[ql[i]]]; 46 if(k<=cnt) { 47 for(int i=0; i<nr; ++i)qr[i]=ls[qr[i]]; 48 for(int i=0; i<nl; ++i)ql[i]=ls[ql[i]]; 49 r=mid; 50 } else { 51 k-=cnt; 52 for(int i=0; i<nr; ++i)qr[i]=rs[qr[i]]; 53 for(int i=0; i<nl; ++i)ql[i]=rs[ql[i]]; 54 l=mid+1; 55 } 56 } 57 return l; 58 } 59 int main() { 60 memset(hd,-1,sizeof hd),ne=0; 61 scanf("%d%d",&n,&m); 62 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 63 for(int i=1; i<=n; ++i)b[i-1]=a[i]; 64 sort(b,b+n),n2=unique(b,b+n)-b; 65 for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1; 66 for(int i=1; i<n; ++i) { 67 int u,v; 68 scanf("%d%d",&u,&v); 69 addedge(u,v),addedge(v,u); 70 } 71 tot=0,dfs1(1,0,1),dfs2(1,1); 72 memset(rt,0,sizeof rt),tot2=0; 73 for(int i=1; i<=n; ++i)upd(rt[i],rt[i-1],a[rnk[i]],1,n2); 74 for(int last=0; m--;) { 75 int u,v,k; 76 scanf("%d%d%d",&u,&v,&k),u^=last; 77 int ans=b[ask(u,v,k)-1]; 78 printf("%d\n",ans),last=ans; 79 } 80 return 0; 81 }
仔细一想这样似乎麻烦了点。因为没有修改操作,我们可以直接用子结点继承父节点的方式来建线段树,然后查询的时候,用u,v的线段树减去lca的线段树再减去lca父节点的线段树即可。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e5+10,inf=0x3f3f3f3f; 5 int hd[N],ne,n,n2,m,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N],ql[100],qr[100],nl,nr; 6 struct E {int v,nxt;} e[N<<1]; 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 8 void dfs1(int u,int f,int d) { 9 fa[u]=f,son[u]=0,siz[u]=1,dep[u]=d; 10 for(int i=hd[u]; ~i; i=e[i].nxt) { 11 int v=e[i].v; 12 if(v==fa[u])continue; 13 dfs1(v,u,d+1),siz[u]+=siz[v]; 14 if(siz[v]>siz[son[u]])son[u]=v; 15 } 16 } 17 void dfs2(int u,int tp) { 18 top[u]=tp,dfn[u]=++tot,rnk[tot]=u; 19 if(son[u])dfs2(son[u],tp); 20 for(int i=hd[u]; ~i; i=e[i].nxt) { 21 int v=e[i].v; 22 if(v==fa[u]||v==son[u])continue; 23 dfs2(v,v); 24 } 25 } 26 #define mid ((l+r)>>1) 27 void upd(int& u,int v,int x,int l=1,int r=n2) { 28 if(!u)u=++tot2; 29 val[u]=val[v]+1; 30 if(l==r)return; 31 if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v]; 32 else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v]; 33 } 34 void dfs3(int u) { 35 upd(rt[u],rt[fa[u]],a[u]); 36 for(int i=hd[u]; ~i; i=e[i].nxt) { 37 int v=e[i].v; 38 if(v==fa[u])continue; 39 dfs3(v); 40 } 41 } 42 int lca(int u,int v) { 43 for(; top[u]!=top[v]; u=fa[top[u]])if(dep[top[u]]<dep[top[v]])swap(u,v); 44 return dep[u]<dep[v]?u:v; 45 } 46 int ask(int u,int v,int w1,int w2,int k,int l=1,int r=n2) { 47 if(l==r)return l; 48 int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]]; 49 return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+1,r); 50 } 51 int main() { 52 memset(hd,-1,sizeof hd),ne=0; 53 scanf("%d%d",&n,&m); 54 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 55 for(int i=1; i<=n; ++i)b[i-1]=a[i]; 56 sort(b,b+n),n2=unique(b,b+n)-b; 57 for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1; 58 for(int i=1; i<n; ++i) { 59 int u,v; 60 scanf("%d%d",&u,&v); 61 addedge(u,v),addedge(v,u); 62 } 63 tot=0,dfs1(1,0,1),dfs2(1,1); 64 memset(rt,0,sizeof rt),tot2=0; 65 dfs3(1); 66 for(int last=0; m--;) { 67 int u,v,k; 68 scanf("%d%d%d",&u,&v,&k),u^=last; 69 int w=lca(u,v); 70 int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w]],k)-1]; 71 printf("%d\n",ans),last=ans; 72 } 73 return 0; 74 }
然后我又测试了倍增和RMQ求LCA的方法,发现居然还不如dfs序+可持久化线段树的方法快~~毕竟倍增和RMQ预处理的时间和空间复杂度都是$O(nlogn)$,而树剖只需要$O(n)$,而且查询速度也比较快。
倍增:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e5+10,inf=0x3f3f3f3f; 5 int hd[N],ne,n,n2,m,fa[N][20],dep[N],rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N]; 6 struct E {int v,nxt;} e[N<<1]; 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 8 #define mid ((l+r)>>1) 9 void upd(int& u,int v,int x,int l=1,int r=n2) { 10 if(!u)u=++tot2; 11 val[u]=val[v]+1; 12 if(l==r)return; 13 if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v]; 14 else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v]; 15 } 16 void dfs(int u,int f,int d) { 17 fa[u][0]=f,dep[u]=d,upd(rt[u],rt[fa[u][0]],a[u]); 18 for(int i=1; i<20; ++i)fa[u][i]=fa[fa[u][i-1]][i-1]; 19 for(int i=hd[u]; ~i; i=e[i].nxt) { 20 int v=e[i].v; 21 if(v==fa[u][0])continue; 22 dfs(v,u,d+1); 23 } 24 } 25 int lca(int u,int v) { 26 if(dep[u]<dep[v])swap(u,v); 27 for(int i=19; dep[u]!=dep[v]; --i)if(dep[fa[u][i]]>=dep[v])u=fa[u][i]; 28 if(u==v)return u; 29 for(int i=19; i>=0; --i)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i]; 30 return fa[u][0]; 31 } 32 int ask(int u,int v,int w1,int w2,int k,int l=1,int r=n2) { 33 if(l==r)return l; 34 int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]]; 35 return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+1,r); 36 } 37 int main() { 38 memset(hd,-1,sizeof hd),ne=0; 39 scanf("%d%d",&n,&m); 40 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 41 for(int i=1; i<=n; ++i)b[i-1]=a[i]; 42 sort(b,b+n),n2=unique(b,b+n)-b; 43 for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1; 44 for(int i=1; i<n; ++i) { 45 int u,v; 46 scanf("%d%d",&u,&v); 47 addedge(u,v),addedge(v,u); 48 } 49 memset(rt,0,sizeof rt),tot2=0; 50 dfs(1,0,1); 51 for(int last=0; m--;) { 52 int u,v,k; 53 scanf("%d%d%d",&u,&v,&k),u^=last; 54 int w=lca(u,v); 55 int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w][0]],k)-1]; 56 printf("%d\n",ans),last=ans; 57 } 58 return 0; 59 }
RMQ:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1e5+10,inf=0x3f3f3f3f; 5 int hd[N],ne,n,n2,m,fa[N],dep[N],pos[N],ST[N<<1][20],Log[N<<1],tot,rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N]; 6 struct E {int v,nxt;} e[N<<1]; 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;} 8 #define mid ((l+r)>>1) 9 void upd(int& u,int v,int x,int l=1,int r=n2) { 10 if(!u)u=++tot2; 11 val[u]=val[v]+1; 12 if(l==r)return; 13 if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v]; 14 else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v]; 15 } 16 void dfs(int u,int f,int d) { 17 fa[u]=f,dep[u]=d,ST[++tot][0]=u,pos[u]=tot,upd(rt[u],rt[fa[u]],a[u]); 18 for(int i=hd[u]; ~i; i=e[i].nxt) { 19 int v=e[i].v; 20 if(v==fa[u])continue; 21 dfs(v,u,d+1),ST[++tot][0]=u; 22 } 23 } 24 bool cmp(int a,int b) {return dep[a]<dep[b];} 25 void initST() { 26 for(int j=1; j<20; ++j) 27 for(int i=1; i+(1<<j)-1<=tot; ++i) 28 ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1],cmp); 29 } 30 int lca(int u,int v) { 31 int l=pos[u],r=pos[v]; 32 if(l>r)swap(l,r); 33 int i=Log[r-l+1]; 34 return min(ST[l][i],ST[r-(1<<i)+1][i],cmp); 35 } 36 int ask(int u,int v,int w1,int w2,int k,int l=1,int r=n2) { 37 if(l==r)return l; 38 int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]]; 39 return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+1,r); 40 } 41 int main() { 42 for(int i=1; i<(N<<1); ++i)Log[i]=log2(i+0.5); 43 memset(hd,-1,sizeof hd),ne=0; 44 scanf("%d%d",&n,&m); 45 for(int i=1; i<=n; ++i)scanf("%d",&a[i]); 46 for(int i=1; i<=n; ++i)b[i-1]=a[i]; 47 sort(b,b+n),n2=unique(b,b+n)-b; 48 for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1; 49 for(int i=1; i<n; ++i) { 50 int u,v; 51 scanf("%d%d",&u,&v); 52 addedge(u,v),addedge(v,u); 53 } 54 memset(rt,0,sizeof rt),tot2=0; 55 dfs(1,0,1),initST(); 56 for(int last=0; m--;) { 57 int u,v,k; 58 scanf("%d%d%d",&u,&v,&k),u^=last; 59 int w=lca(u,v); 60 int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w]],k)-1]; 61 printf("%d\n",ans),last=ans; 62 } 63 return 0; 64 }