[DP优化方法之虚树]
首先我们看一篇文章 转自xyz:
给出一棵树.
每次询问选择一些点,求一些东西.这些东西的特点是,许多未选择的点可以通过某种方式剔除而不影响最终结果.
于是就有了建虚树这个技巧.....
我们可以用log级别的时间求出点对间的lca....
那么,对于每个询问我们根据原树的信息重新建树,这棵树中要尽量少地包含未选择节点. 这棵树就叫做虚树.
接下来所说的"树"均指虚树,原来那棵树叫做"原树".
构建过程如下:
按照原树的dfs序号(记为dfn)递增顺序遍历选择的节点. 每次遍历节点都把这个节点插到树上.
首先虚树一定要有一个根. 随便扯一个不会成为询问点的点作根.
维护一个栈,它表示在我们已经(用之前的那些点)构建完毕的虚树上,以最后一个插入的点为端点的DFS链.
设最后插入的点为p(就是栈顶的点),当前遍历到的点为x.我们想把x插入到我们已经构建的树上去.
求出lca(p,x),记为lca.有两种情况:
1.p和x分立在lca的两棵子树下.
2.lca是p.
(为什么lca不能是x?
因为如果lca是x,说明dfn(lca)=dfn(x)<dfn(a),而我们是按照dfs序号遍历的,于是dfn(a)<dfn(x),矛盾.)
对于第二种情况,直接在栈中插入节点x即可,不要连接任何边(后面会说为什么).
对于第一种情况,要仔细分析.
我们是按照dfs序号遍历的(因为很重要所以多说几遍......),有dfn(x)>dfn(p)>dfn(lca).
这说明什么呢? 说明一件很重要的事:我们已经把lca所引领的子树中,p所在的子树全部遍历完了!
简略的证明:如果没有遍历完,那么肯定有一个未加入的点h,满足dfn(h)<dfn(x),
我们按照dfs序号递增顺序遍历的话,应该把h加进来了才能考虑x.
这样,我们就直接构建lca引领的,p所在的那个子树. 我们在退栈的时候构建子树.
p所在的子树如果还有其它部分,它一定在之前就构建好了(所有退栈的点都已经被正确地连入树中了),就剩那条链.
如何正确地把p到lca那部分连进去呢?
设栈顶的节点为p,栈顶第二个节点为q.
重复以下操作:
如果dfn(q)>dfn(lca),可以直接连边q->p,然后退一次栈.
如果dfn(q)=dfn(lca),说明q=lca,直接连边lca->p,此时子树已经构建完毕.
如果dfn(q)<dfn(lca),说明lca被p与q夹在中间,此时连边lca->q,退一次栈,再把lca压入栈.此时子树构建完毕.
如果不理解这样操作的缘由可以画画图.....
最后,为了维护dfs链,要把x压入栈. 整个过程就是这样.....
然后就是我自己的理解了 我觉得我的理解虽然不是很严谨但是很容易懂
其实说白了就是如果我找到一个点不在这条链上 然后我们就跳栈顶的点使得栈顶的点和第二栈顶的点夹着lca 当然有可能第二栈顶的点就是lca 每次跳的时候连边
然后弹掉栈顶的点 如果现在栈顶的点不是lca就把lca塞进去 不是now的点就把now塞进去(这个应该是怕同此询问有重复的点吧 我去掉也ac)
然后的话虚树解决的就是总询问点数很少 询问次数很多的题 然后后面的记得清空就好
top=0; S[++top]=1; Plen=0; P[++Plen]=1; for(LL i=1;i<=K;i++) { LL now=H[i]; LL f=lca(S[top],now); while(dfn[S[top-1]]>dfn[f]){ins(1,S[top-1],S[top],0); top--;} if(dfn[S[top]]>dfn[f]){ins(1,f,S[top],0); top--;} if(S[top]!=f) S[++top]=f,P[++Plen]=f; S[++top]=now,P[++Plen]=now; } while(top>1){ins(1,S[top-1],S[top],0); top--;}
h数组是询问的点要按dfn序排一下
[Sdoi2011消耗战 |
这是一道模版题 找到所有点建完虚树后 然后dp 要删去一些边且费用最小 想一想 真正有用的也就只是这些点还有lca的点 所以的话dp一下 要不选下面点一直到根的最小值的和 要不就选lca到根最小值
#include<cstdio> #include<iostream> #include<cstring> #include<algorithm> #include<cstdlib> #include<cmath> #include<vector> #include<climits> #define Maxn 250010 using namespace std; typedef long long LL; struct node{LL x,y,next,d;}edge[2][Maxn*2]; LL len[2],first[2][Maxn]; void ins(LL k,LL x,LL y,LL d){len[k]++; edge[k][len[k]].x=x; edge[k][len[k]].y=y; edge[k][len[k]].d=d; edge[k][len[k]].next=first[k][x]; first[k][x]=len[k];} LL dep[Maxn],fa[Maxn][21],minx[Maxn]; LL dfn[Maxn],id=0; LL N,M; void Dfs(LL x,LL f) { dfn[x]=++id; for(LL k=first[0][x];k!=-1;k=edge[0][k].next) { LL y=edge[0][k].y; if(y!=f){dep[y]=dep[x]+1; fa[y][0]=x; minx[y]=min(minx[x],edge[0][k].d); Dfs(y,x);} } } LL H[Maxn],K; LL top,S[Maxn]; bool Cmp(const LL &x,const LL &y){return dfn[x]<dfn[y];} LL lca(LL x,LL y) { if(dep[x]<dep[y]) swap(x,y); LL deep=dep[x]-dep[y]; for(LL i=20;i>=0;i--) if(deep>=(1<<i)){deep-=(1<<i); x=fa[x][i];} if(x==y) return x; for(LL i=20;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } LL F[Maxn]; bool C[Maxn]; void DP(LL x) { F[x]=minx[x]; if(C[x]) return ; LL tmp=0; for(LL k=first[1][x];k!=-1;k=edge[1][k].next) { LL y=edge[1][k].y; DP(y); tmp+=F[y]; } if(tmp<F[x]) F[x]=tmp; } LL P[Maxn],Plen; void Solve() { for(LL i=1;i<=Plen;i++) first[1][P[i]]=-1; len[1]=0; scanf("%lld",&K); for(LL i=1;i<=K;i++) scanf("%lld",&H[i]),C[H[i]]=1; sort(H+1,H+K+1,Cmp); top=0; S[++top]=1; Plen=0; P[++Plen]=1; for(LL i=1;i<=K;i++) { LL now=H[i]; LL f=lca(S[top],now); while(dfn[S[top-1]]>dfn[f]){ins(1,S[top-1],S[top],0); top--;} if(dfn[S[top]]>dfn[f]){ins(1,f,S[top],0); top--;} if(S[top]!=f) S[++top]=f,P[++Plen]=f; if(S[top]!=now) S[++top]=now,P[++Plen]=now; } while(top>1){ins(1,S[top-1],S[top],0); top--;} DP(1); printf("%lld\n",F[1]); for(LL i=1;i<=K;i++) C[H[i]]=0; } int main() { scanf("%lld",&N); len[0]=0; memset(first[0],-1,sizeof(first[0])); for(LL i=1;i<N;i++){LL x,y,d; scanf("%lld%lld%lld",&x,&y,&d); ins(0,x,y,d); ins(0,y,x,d);} dep[1]=1; for(LL i=1;i<=N;i++) minx[i]=LLONG_MAX; Dfs(1,0); for(LL j=1;j<=20;j++) for(LL i=1;i<=N;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; scanf("%lld",&M); len[1]=0; memset(first[1],-1,sizeof(first[1])); memset(C,0,sizeof(C)); for(LL i=1;i<=M;i++) Solve(); return 0; }
[Hnoi2014]世界树 |
这一道题就比较劲了
一些点管辖整个树 这些点是给定的 首先我们建一颗虚树 然后因为虚树上有一些点是lca的 也就是空的 我们要把这些点dp一下看看最近去到哪里
然后的话对于虚树上每两个相临的节点 我们二分这两个节点的链 也就是原树上的链 然后找到中间点 切开之后分别属于那两边
但是我们忽略了一个地方 就是有一些点没有被找过 他们为那些询问点下面的 而且下面没有询问点了 那怎么办呢 这些点肯定是跟着上面祖先选什么我就选什么的 这样的话就在祖先那里统计一下扫过了多少个点 剩下没被扫过的就是祖先的
说起来容易打起来难,这道算是经典题 不做不算是会虚树
#include<cstdio> #include<iostream> #include<cstring> #include<algorithm> #include<cstdlib> #include<cmath> #include<climits> #define Maxn 300010 using namespace std; const int inf=1e9; struct node { int x,y,next,d; }edge[2][Maxn*2]; int len[2],first[2][Maxn]; void ins(int k,int x,int y,int d){len[k]++; edge[k][len[k]].x=x; edge[k][len[k]].y=y; edge[k][len[k]].next=first[k][x]; first[k][x]=len[k];} int N,Q; int deep[Maxn],size[Maxn],fa[21][Maxn]; int dfn[Maxn],id=0; int unc[Maxn]; bool Cmp(const int &x,const int &y){return dfn[x]<dfn[y];} void Dfs(int x,int f) { size[x]=1; dfn[x]=++id; for(int k=first[0][x];k!=-1;k=edge[0][k].next) { int y=edge[0][k].y; if(y!=f) { deep[y]=deep[x]+1; fa[0][y]=x; Dfs(y,x); size[x]+=size[y]; } } } int lca(int x,int y) { if(deep[x]<deep[y]) swap(x,y); int d=(deep[x]-deep[y]); for(int i=20;i>=0;i--) if((1<<i)<=d) d-=(1<<i),x=fa[i][x]; if(x==y) return x; for(int i=20;i>=0;i--) if(fa[i][x]!=fa[i][y]) x=fa[i][x],y=fa[i][y]; return fa[0][x]; } int H[Maxn]; int P[Maxn],S[Maxn],top,plen; bool C[Maxn]; int dis(int x,int y){return deep[x]+deep[y]-2*deep[lca(x,y)];} pair<int,int> F1[Maxn],F2[Maxn],G[Maxn]; void Dfs1(int x) { if(C[x]) F1[x]=F2[x]=make_pair(0,x); for(int k=first[1][x];k!=-1;k=edge[1][k].next) { int y=edge[1][k].y; Dfs1(y); if(!C[x]) { int D=dis(x,y); if((F1[x].first>F1[y].first+D)||(F1[x].first==F1[y].first+D&&F1[x].second>F1[y].second)) F2[x]=F1[x],F1[x]=make_pair(F1[y].first+D,F1[y].second); else if((F2[x].first>F1[y].first+D)||(F2[x].first==F1[y].first+D&&F2[x].second>F1[y].second)) F2[x]=make_pair(F1[y].first+D,F1[y].second); } } } int F[Maxn]; void Dfs2(int x,int f) { if(!C[x]) { G[x].first=G[f].first+dis(x,f); G[x].second=G[f].second; if(F1[f].second==F1[x].second) { if((F2[f].first+dis(f,x)<G[x].first)||(F2[f].first+dis(f,x)==G[x].first&&F2[f].second<G[x].second)) G[x].second=F2[f].second,G[x].first=F2[f].first+dis(f,x); } else if((F1[f].first+dis(f,x)<G[x].first)||(F1[f].first+dis(f,x)==G[x].first&&F1[f].second<G[x].second)) G[x].second=F1[f].second,G[x].first=F1[f].first+dis(f,x); } else G[x]=make_pair(0,x); if(C[x]) F[x]=x; else { if(G[x].first<F1[x].first||(G[x].first==F1[x].first&&G[x].second<F1[x].second)) F[x]=G[x].second; if(G[x].first>F1[x].first||(G[x].first==F1[x].first&&G[x].second>F1[x].second)) F[x]=F1[x].second; } for(int k=first[1][x];k!=-1;k=edge[1][k].next) { int y=edge[1][k].y; Dfs2(y,x); } } int Find(int x,int D){for(int i=20;i>=0;i--) if(D>=(1<<i)) D-=(1<<i),x=fa[i][x]; return x;} int ans[Maxn]; void DP(int x) { ans[F[x]]++; unc[x]--; for(int k=first[1][x];k!=-1;k=edge[1][k].next) { int y=edge[1][k].y; int L=Find(y,deep[y]-deep[x]-1); int R=fa[0][y]; int sizex=size[L]; unc[x]-=sizex; int ret=x; if(F[x]!=F[y]) { if(deep[L]<=deep[R]) { while(deep[L]<=deep[R]) { int mid=(deep[L]+deep[R])>>1; int midx=Find(y,deep[y]-mid); int disx=dis(F[x],midx); int disy=dis(F[y],midx); if(disx>disy||(disx==disy&&F[x]>F[y])) R=Find(y,deep[y]-(mid-1)); else if(disx<disy||(disx==disy&&F[x]<F[y])) L=Find(y,deep[y]-(mid+1)),ret=midx; } ans[F[x]]+=size[Find(y,deep[y]-deep[x]-1)]-size[Find(y,deep[y]-deep[ret]-1)]; ans[F[y]]+=size[Find(y,deep[y]-deep[ret]-1)]-size[y]; } } else ans[F[x]]+=size[Find(y,deep[y]-deep[x]-1)]-size[y]; } for(int k=first[1][x];k!=-1;k=edge[1][k].next) { int y=edge[1][k].y; DP(y); } } int B[Maxn]; void Solve() { int K; scanf("%d",&K); for(int i=1;i<=K;i++){scanf("%d",&H[i]); B[i]=H[i]; C[H[i]]=1;} sort(H+1,H+K+1,Cmp); top=1; S[top]=1; plen=1; P[1]=1; for(int i=1;i<=K;i++) { int now=H[i]; int f=lca(now,S[top]); while(dfn[S[top-1]]>dfn[f]) ins(1,S[top-1],S[top],0),top--; if(dfn[S[top]]>dfn[f]) ins(1,f,S[top],0),top--; if(S[top]!=f) S[++top]=f,P[++plen]=f; if(S[top]!=now) S[++top]=now,P[++plen]=now; } while(top>1) ins(1,S[top-1],S[top],0),top--; Dfs1(1); Dfs2(1,0); for(int i=1;i<=plen;i++) unc[P[i]]=size[P[i]]; DP(1); for(int i=1;i<=plen;i++) ans[F[P[i]]]+=unc[P[i]]; for(int i=1;i<=K;i++) printf("%d ",ans[B[i]]); printf("\n"); for(int i=1;i<=K;i++) ans[H[i]]=0; for(int i=1;i<=K;i++) C[H[i]]=0; for(int i=1;i<=plen;i++) F1[P[i]].first=F1[P[i]].second=F2[P[i]].first=F2[P[i]].second=G[P[i]].first=G[P[i]].second=F[P[i]]=inf,first[1][P[i]]=-1; len[1]=0; } int main() { scanf("%d",&N); len[0]=0; memset(first[0],-1,sizeof(first[0])); for(int i=1;i<N;i++){int x,y; scanf("%d%d",&x,&y); ins(0,x,y,1); ins(0,y,x,1);} Dfs(1,0); for(int i=1;i<=20;i++) for(int j=1;j<=N;j++) fa[i][j]=fa[i-1][fa[i-1][j]]; memset(first[1],-1,sizeof(first[1])); len[1]=0; for(int i=0;i<=N;i++) F1[i].first=F1[i].second=F2[i].first=F2[i].second=G[i].first=G[i].second=F[i]=inf; for(int i=1;i<=N;i++) ans[i]=0; scanf("%d",&Q); for(int i=1;i<=Q;i++) Solve(); return 0; } /* 10 2 1 3 2 4 3 5 4 6 1 7 3 8 3 9 4 10 1 5 2 6 1 5 2 7 3 6 9 1 8 4 8 7 10 3 5 2 9 3 5 8 */