虚树初探
虚树其实没什么的。。
只是因为点太多了不能全开于是只开那些需要用到的点。
一棵虚树包括要求点以及它们的lca。。
虚树的构建。。。(其实感觉如果会虚树的构建的话,接下来就是树dp啦,没什么的。。。)
首先我们应该对整棵树dfs,求出它的dfs序列。然后对于给的点,按dfs序排序。。
因为我们是按dfs序排列的,所以虚树一定是由一条条链构成的。。
扫一遍给的点,如果这个点在当前的这条链上,那加在栈顶就可以了。
如果不是的话,那就不断地退栈,使得原来的那条链上面的边全部被加到边集中。。
// Virtual-tree construction over the query nodes a[1..n].
// This is a fragment: rep, a, s, top, dep, lca, dis and insert must be
// provided by the surrounding program. The accompanying text states that
// a[] has already been sorted by DFS order.
// s[1..top] is used as a stack of nodes; insert(x,y,w) records an edge x -> y
// with weight w.
rep(i,1,n){
    // t is the next query node; f is the LCA of t and the current stack top.
    // NOTE(review): this is evaluated before the `while (top)` test, so when
    // top==0 it reads s[0] -- confirm s[0] holds a harmless value.
    int t=a[i],f=lca(t,s[top]);;
    while (top){
        // f is shallower than the second-from-top node: emit the edge
        // s[top-1] -> s[top], pop, and keep unwinding.
        if (top>1&&dep[f]<dep[s[top-1]])
            insert(s[top-1],s[top],dis(s[top-1],s[top])),top--;
        // f is shallower than the top only: attach the top under f, pop, stop.
        else if (dep[f]<dep[s[top]])
            {insert(f,s[top],dis(f,s[top])); top--;break;}
        // dep[f] >= dep[s[top]]: nothing needs to be popped.
        else break;
    }
    // Push f unless it is already the stack top, then push t itself.
    if (f!=s[top]) s[++top]=f;
    s[++top]=t;
}
// Flush what is left on the stack: each node becomes a child of the one below it.
while (top>1) insert(s[top-1],s[top],dis(s[top-1],s[top])),top--;
【bzoj3611】[Heoi2014]大工程
两两都要连一条边嘛,记f[i]=i子树上所有要求点到i这个点的距离和,sz[i]表示i子树上要求点个数。
那么合并儿子v(边权为e[j].c)时,先更新答案 sum+=(sz[u]*e[j].c+f[u])*sz[v]+f[v]*sz[u],再更新 f[u]+=e[j].c*sz[v]+f[v],最后 sz[u]+=sz[v]。
// BZOJ 3611 [HEOI2014]: virtual tree + tree DP.
//
// Input : n, then n-1 tree edges, then Q, then for each query a count k
//         followed by k node ids.
// Output: per query one line "sum min max" -- the sum, the minimum and the
//         maximum of the distances (counted in edges) over all pairs of
//         query nodes.
//
// Each preprocessor directive sits on its own line (a directive cannot share
// a physical line with another one), and <cctype> is included for isdigit.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>

typedef long long ll;

const int maxn = 1005000;
const int LOG = 20;                       // binary-lifting levels 0..LOG
const ll inf = 1152921504606846976LL;     // 2^60; inf + inf still fits in ll

// Forward-star adjacency list. It first stores the input tree and is then
// reused (head[] cleared, tot reset) for the virtual tree of every query.
struct Edge {
    int obj;  // edge target
    int pre;  // previous edge index of the same source, 0 terminates
    ll c;     // edge weight (0 in the input tree, depth difference in the virtual tree)
} e[maxn * 2];

int head[maxn], dfn[maxn], dep[maxn], fa[maxn][22], bin[22];
int a[maxn], bel[maxn], sz[maxn], s[maxn];
ll mx[maxn], mn[maxn], f[maxn], ans, ans1, ans2;
int n, Q, tot, top, idx;

// Adds the directed edge x -> y with weight z.
void insert(int x, int y, ll z) {
    e[++tot].obj = y;
    e[tot].pre = head[x];
    head[x] = tot;
    e[tot].c = z;
}

// Reads one (optionally signed) integer from stdin.
// The character is kept in an int so that EOF is representable and isdigit
// never receives a negative char; on end of input the function returns what
// it has parsed so far (0 if nothing) instead of looping forever.
int read() {
    int x = 0, sign = 1;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Assigns DFS order, depth and binary-lifting ancestors for the subtree of u.
// NOTE(review): recursion depth equals the tree height; a very deep tree may
// need an enlarged stack -- confirm against the judge's limits.
void dfs(int u) {
    dfn[u] = ++idx;
    for (int i = 1; i <= LOG; i++)
        if (dep[u] >= bin[i]) fa[u][i] = fa[fa[u][i - 1]][i - 1];
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        if (v != fa[u][0]) {
            fa[v][0] = u;
            dep[v] = dep[u] + 1;
            dfs(v);
        }
    }
}

// Lowest common ancestor by binary lifting.
int lca(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    const int t = dep[x] - dep[y];
    for (int i = 0; i <= LOG; i++)
        if (t & bin[i]) x = fa[x][i];
    for (int i = LOG; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    if (x == y) return x;
    return fa[x][0];
}

// Number of edges on the tree path between x and y.
ll dis(int x, int y) {
    return 1LL * (dep[x] + dep[y] - 2 * dep[lca(x, y)]);
}

// Orders nodes by DFS entry time.
bool cmp(int x, int y) {
    return dfn[x] < dfn[y];
}

// Tree DP on the virtual tree rooted at u.
//   sz[u] : number of query nodes in u's subtree
//   f[u]  : sum of distances from those query nodes to u
//   mn/mx : shortest / longest distance from u down to a query node
// While merging child v, pairs with one endpoint in the already-merged part
// and the other in v's subtree are added to ans / ans1 / ans2 before u's own
// values are updated. head[u] is cleared on the way out so the adjacency
// list is clean for the next query.
void dp(int u) {
    sz[u] = bel[u];
    f[u] = 0;
    if (bel[u]) {
        mn[u] = 0;
        mx[u] = 0;
    } else {
        mn[u] = inf;
        mx[u] = -inf;
    }
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        const ll w = e[j].c;
        dp(v);
        ans += (sz[u] * w + f[u]) * sz[v] + f[v] * sz[u];
        f[u] += w * sz[v] + f[v];
        ans1 = std::min(ans1, mn[u] + w + mn[v]);
        ans2 = std::max(ans2, mx[u] + w + mx[v]);
        mn[u] = std::min(mn[u], mn[v] + w);
        mx[u] = std::max(mx[u], mx[v] + w);
        sz[u] += sz[v];
    }
    head[u] = 0;
}

// Handles one query: reads the node set, builds its virtual tree rooted at
// node 1 with a monotone stack, runs the DP and prints the three answers.
void solve() {
    tot = top = idx = 0;
    const int k = read();
    for (int i = 1; i <= k; i++) {
        a[i] = read();
        bel[a[i]] = 1;
    }
    std::sort(a + 1, a + 1 + k, cmp);

    // Node 1 is always the virtual root; push it unless it is itself a query
    // node (then it is a[1] and gets pushed by the loop below).
    if (bel[1] != 1) s[++top] = 1;
    for (int i = 1; i <= k; i++) {
        const int t = a[i];
        int anc = 0;
        while (top) {
            anc = lca(t, s[top]);
            if (top > 1 && dep[anc] < dep[s[top - 1]]) {
                // The LCA is shallower than the second-from-top node: finish
                // the edge between the two topmost nodes and keep unwinding.
                insert(s[top - 1], s[top], dis(s[top - 1], s[top]));
                top--;
            } else if (dep[anc] < dep[s[top]]) {
                // The LCA is shallower than the top only: hang the top under it.
                insert(anc, s[top], dis(anc, s[top]));
                top--;
                break;
            } else {
                break;
            }
        }
        if (anc != s[top]) s[++top] = anc;
        s[++top] = t;
    }
    // Whatever is left on the stack is linked node to node.
    while (top > 1) {
        insert(s[top - 1], s[top], dis(s[top - 1], s[top]));
        top--;
    }

    ans = 0;
    ans1 = inf;
    ans2 = -inf;
    dp(1);
    printf("%lld %lld %lld\n", ans, ans1, ans2);

    // Undo the query marks so the next query starts clean.
    for (int i = 1; i <= k; i++) bel[a[i]] = 0;
}

int main() {
    bin[0] = 1;
    for (int i = 1; i <= LOG; i++) bin[i] = bin[i - 1] * 2;
    n = read();
    for (int i = 1; i <= n - 1; i++) {
        const int x = read();
        const int y = read();
        insert(x, y, 0);
        insert(y, x, 0);
    }
    dfs(1);
    // The input tree's adjacency is no longer needed; head[] is reused for
    // the per-query virtual trees.
    std::memset(head, 0, sizeof(head));
    Q = read();
    while (Q--) solve();
    return 0;
}
3572: [Hnoi2014]世界树
先把虚树建出来,然后跑两遍dfs得到虚树上每个点的belong,对于虚树上每条边a,b,可以通过二分(代码里用倍增实现)得到一个点x使得a到x-1的都让a管,x到b的都让b管。。
// BZOJ 3572 [HNOI2014]: virtual tree + two DFS passes + binary lifting.
//
// Input : n, then n-1 tree edges, then Q, then for each query a count k
//         followed by k node ids.
// Output: per query one line with k numbers (each followed by a space): for
//         every query node, in input order, the number of tree nodes that are
//         assigned to it. A node is assigned to the closest query node;
//         ties go to the smaller node id.
//
// Each preprocessor directive sits on its own line (a directive cannot share
// a physical line with another one), and <cctype> is included for isdigit.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>

typedef long long ll;

const int maxn = 300500;
const int LOG = 20;  // binary-lifting levels 0..LOG

// Forward-star adjacency list. It first stores the input tree and is then
// reused (head[] cleared, tot reset) for the virtual tree of every query.
struct Edge {
    int obj;  // edge target
    int pre;  // previous edge index of the same source, 0 terminates
} e[maxn * 2];

int head[maxn], c[maxn], dfn[maxn], sz[maxn], fa[maxn][22], bin[22];
int rem[maxn], s[maxn], dep[maxn], bel[maxn], f[maxn], a[maxn], b[maxn];
int tot, idx, n, top, Q;

// Reads one (optionally signed) integer from stdin.
// The character is kept in an int so that EOF is representable and isdigit
// never receives a negative char; on end of input the function returns what
// it has parsed so far (0 if nothing) instead of looping forever.
int read() {
    int x = 0, sign = 1;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Adds the directed edge x -> y.
void insert(int x, int y) {
    e[++tot].obj = y;
    e[tot].pre = head[x];
    head[x] = tot;
}

// Assigns DFS order, depth, subtree size and binary-lifting ancestors.
// NOTE(review): recursion depth equals the tree height; a very deep tree may
// need an enlarged stack -- confirm against the judge's limits.
void dfs(int u) {
    dfn[u] = ++idx;
    sz[u] = 1;
    for (int i = 1; i <= LOG; i++)
        if (dep[u] >= bin[i]) fa[u][i] = fa[fa[u][i - 1]][i - 1];
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        if (v != fa[u][0]) {
            fa[v][0] = u;
            dep[v] = dep[u] + 1;
            dfs(v);
            sz[u] += sz[v];
        }
    }
}

// Lowest common ancestor by binary lifting.
int lca(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    const int t = dep[x] - dep[y];
    for (int i = 0; i <= LOG; i++)
        if (t & bin[i]) x = fa[x][i];
    for (int i = LOG; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    if (x == y) return x;
    return fa[x][0];
}

// Number of edges on the tree path between x and y.
int dis(int x, int y) {
    return dep[x] + dep[y] - 2 * dep[lca(x, y)];
}

// Bottom-up pass over the virtual tree: records the visit order in c[],
// initialises rem[u] with the full subtree size, and pulls the best owner
// (closest query node, smaller id on ties) up from the children.
void dfs1(int u) {
    c[++idx] = u;
    rem[u] = sz[u];
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        dfs1(v);
        if (!bel[v]) continue;
        if (!bel[u]) {
            // u has no owner yet: take the child's without measuring a
            // distance to the non-existent node 0.
            bel[u] = bel[v];
            continue;
        }
        const int t1 = dis(bel[u], u);
        const int t2 = dis(bel[v], u);
        if (t1 > t2 || (t1 == t2 && bel[u] > bel[v])) bel[u] = bel[v];
    }
}

// Top-down pass: lets every child adopt its parent's owner when that owner
// is closer (or equally close with a smaller id).
void dfs2(int u) {
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        if (!bel[v]) {
            bel[v] = bel[u];
        } else {
            const int t1 = dis(bel[u], v);
            const int t2 = dis(bel[v], v);
            if (t1 < t2 || (t1 == t2 && bel[u] < bel[v])) bel[v] = bel[u];
        }
        dfs2(v);
    }
}

// Distributes the original-tree nodes hanging off the virtual edge
// par -> child (par is the ancestor) between the owners of its endpoints.
void solve(int par, int child) {
    int x = child;    // will become the child of par on the path to `child`
    int mid = child;  // will become the highest path node owned by bel[child]
    for (int i = LOG; i >= 0; i--)
        if (dep[fa[x][i]] > dep[par]) x = fa[x][i];
    // Everything below x is accounted for by this edge, not by par itself.
    rem[par] -= sz[x];
    if (bel[par] == bel[child]) {
        f[bel[par]] += sz[x] - sz[child];
        return;
    }
    for (int i = LOG; i >= 0; i--) {
        const int nxt = fa[mid][i];
        if (dep[nxt] <= dep[par]) continue;
        const int t1 = dis(bel[par], nxt);
        const int t2 = dis(bel[child], nxt);
        if (t1 > t2 || (t1 == t2 && bel[child] < bel[par])) mid = nxt;
    }
    f[bel[par]] += sz[x] - sz[mid];
    f[bel[child]] += sz[mid] - sz[child];
}

// Orders nodes by DFS entry time.
bool cmp(int x, int y) {
    return dfn[x] < dfn[y];
}

// Handles one query: reads the node set, builds its virtual tree rooted at
// node 1 with a monotone stack, computes owners and counts, prints the
// counts and resets all per-query state.
void query() {
    idx = tot = top = 0;
    const int k = read();
    for (int i = 1; i <= k; i++) {
        a[i] = read();
        b[i] = a[i];        // keep the input order for printing
        bel[a[i]] = a[i];   // a query node owns itself
    }
    std::sort(a + 1, a + 1 + k, cmp);

    // Node 1 is always the virtual root; push it unless it is itself a query
    // node (then it is a[1] and gets pushed by the loop below).
    if (bel[1] != 1) s[++top] = 1;
    for (int i = 1; i <= k; i++) {
        const int t = a[i];
        int anc = 0;
        while (top) {
            anc = lca(t, s[top]);
            if (top > 1 && dep[anc] < dep[s[top - 1]]) {
                insert(s[top - 1], s[top]);
                top--;
            } else if (dep[anc] < dep[s[top]]) {
                insert(anc, s[top]);
                top--;
                break;
            } else {
                break;
            }
        }
        if (anc != s[top]) s[++top] = anc;
        s[++top] = t;
    }
    while (top > 1) {
        insert(s[top - 1], s[top]);
        top--;
    }

    dfs1(1);
    dfs2(1);
    for (int i = 1; i <= idx; i++)
        for (int j = head[c[i]]; j; j = e[j].pre) solve(c[i], e[j].obj);
    // What is left in rem[] belongs to the owner of the virtual node itself.
    for (int i = 1; i <= idx; i++) f[bel[c[i]]] += rem[c[i]];
    for (int i = 1; i <= k; i++) printf("%d ", f[b[i]]);
    puts("");

    // Reset only the nodes touched by this query.
    for (int i = 1; i <= idx; i++) {
        const int u = c[i];
        head[u] = rem[u] = f[u] = bel[u] = 0;
        c[i] = 0;
    }
}

int main() {
    bin[0] = 1;
    for (int i = 1; i <= LOG; i++) bin[i] = bin[i - 1] * 2;
    n = read();
    for (int i = 1; i <= n - 1; i++) {
        const int x = read();
        const int y = read();
        insert(x, y);
        insert(y, x);
    }
    dfs(1);
    // The input tree's adjacency is no longer needed; head[] is reused for
    // the per-query virtual trees.
    std::memset(head, 0, sizeof(head));
    Q = read();
    while (Q--) query();
    return 0;
}
2286: [Sdoi2011]消耗战
// BZOJ 2286 [SDOI2011]: virtual tree + tree DP, LCA by heavy-path decomposition.
//
// Input : n, then n-1 weighted tree edges "x y z", then the number of
//         queries, then for each query a count m followed by m node ids.
// Output: per query one line with f[1], where for a virtual-tree node u
//         f[u] = mn[u] if u has no virtual children, otherwise
//         f[u] = min(mn[u], sum of f over its virtual children), and mn[u]
//         is the smallest edge weight on the path from node 1 to u.
//         (Presumably the cheapest set of edges to cut so that no query node
//         stays connected to node 1 -- confirm against the problem statement.)
//
// Each preprocessor directive sits on its own line (a directive cannot share
// a physical line with another one), and <cctype> is included for isdigit.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>

typedef long long ll;

const int maxn = 250250;
const ll inf = 1152921504606846976LL;  // 2^60

// Forward-star edge; e[] holds the input tree, ed[] the per-query virtual tree.
struct Edge {
    int obj;  // edge target
    int pre;  // previous edge index of the same source, 0 terminates
    ll c;     // edge weight (unused for virtual-tree edges)
} e[maxn * 2], ed[maxn * 2];

int st[maxn], h[maxn], mark[maxn], head[maxn], head2[maxn], dep[maxn];
int top[maxn], s[maxn], son[maxn], fa[maxn];
ll f[maxn], mn[maxn];
int n, m, tt, tot, idx;

// Reads one (optionally signed) integer from stdin.
// The character is kept in an int so that EOF is representable and isdigit
// never receives a negative char; on end of input the function returns what
// it has parsed so far (0 if nothing) instead of looping forever.
ll read() {
    ll x = 0, sign = 1;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Adds the directed input-tree edge x -> y with weight z.
// The weight is taken as ll so it is not narrowed on the way into Edge::c.
void insert(int x, int y, ll z) {
    e[++tot].obj = y;
    e[tot].c = z;
    e[tot].pre = head[x];
    head[x] = tot;
}

// Adds the directed virtual-tree edge x -> y; self-loops are ignored.
void insert2(int x, int y) {
    if (x == y) return;
    ed[++tot].obj = y;
    ed[tot].pre = head2[x];
    head2[x] = tot;
}

// First decomposition pass: DFS order (mark), parent, depth, subtree size (s),
// heavy child (son) and mn[v] = smallest edge weight between node 1 and v.
// NOTE(review): recursion depth equals the tree height; a very deep tree may
// need an enlarged stack -- confirm against the judge's limits.
void dfs(int u) {
    mark[u] = ++idx;
    s[u] = 1;
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        if (v != fa[u]) {
            fa[v] = u;
            dep[v] = dep[u] + 1;
            mn[v] = std::min(mn[u], e[j].c);
            dfs(v);
            s[u] += s[v];
            if (s[v] > s[son[u]]) son[u] = v;
        }
    }
}

// Second decomposition pass: top[u] = head of the heavy chain containing u.
void make(int u, int ff) {
    top[u] = ff;
    if (son[u]) make(son[u], ff);
    for (int j = head[u]; j; j = e[j].pre) {
        const int v = e[j].obj;
        if (v != son[u] && v != fa[u]) make(v, v);
    }
}

// Lowest common ancestor by jumping along heavy chains.
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

// Orders nodes by DFS entry time.
bool cmp(int x, int y) {
    return mark[x] < mark[y];
}

// Tree DP on the virtual tree (see the file header for the recurrence).
// head2[u] is cleared on the way out so the next query starts clean.
void dp(int u) {
    ll tmp = 0;
    f[u] = mn[u];
    for (int j = head2[u]; j; j = ed[j].pre) {
        const int v = ed[j].obj;
        dp(v);
        tmp += f[v];
    }
    head2[u] = 0;
    if (tmp == 0)
        f[u] = mn[u];
    else
        f[u] = std::min(f[u], tmp);
}

// Handles one query: reads the node set, drops every node that has another
// query node as an ancestor, builds the virtual tree rooted at node 1 and
// prints the DP value of the root.
void solve() {
    m = static_cast<int>(read());
    for (int i = 1; i <= m; i++) h[i] = static_cast<int>(read());
    std::sort(h + 1, h + 1 + m, cmp);

    // Keep h[i] only if the last kept node is not its ancestor.
    int cnt = 0;
    h[++cnt] = h[1];
    for (int i = 2; i <= m; i++) {
        const int t = lca(h[cnt], h[i]);
        if (t != h[cnt]) h[++cnt] = h[i];
    }

    tot = 0;
    int sp = 0;  // stack pointer into st[]; st[0] stays 0 and dep[0] is 0
    st[++sp] = 1;
    for (int i = 1; i <= cnt; i++) {
        const int now = h[i];
        const int t = lca(st[sp], now);
        while (true) {
            if (dep[t] >= dep[st[sp - 1]]) {
                // t is not shallower than the second-from-top node: hang the
                // top under t and make sure t is on the stack.
                insert2(t, st[sp]);
                --sp;
                if (st[sp] != t) st[++sp] = t;
                break;
            }
            insert2(st[sp - 1], st[sp]);
            --sp;
        }
        if (st[sp] != now) st[++sp] = now;
    }
    // Link the remaining stack, node to node.
    while (--sp) insert2(st[sp], st[sp + 1]);

    dp(1);
    printf("%lld\n", f[1]);
}

int main() {
    n = static_cast<int>(read());
    for (int i = 1; i <= n - 1; i++) {
        const int x = static_cast<int>(read());
        const int y = static_cast<int>(read());
        const ll z = read();
        insert(x, y, z);
        insert(y, x, z);
    }
    tt = static_cast<int>(read());
    dep[1] = 1;     // root depth 1 keeps dep[0] == 0 strictly smaller
    mn[1] = inf;    // no edge above the root
    dfs(1);
    make(1, 1);
    while (tt--) solve();
    return 0;
}