BZOJ 3572: [Hnoi2014]世界树 [虚树 DP 倍增]
题意:
一棵树,多次询问,给出$m$个点,求有几个点到给定点最近
写了一晚上...
当然要建虚树了,但是怎么$DP$啊
我们先求出到虚树上某个点最近的关键点
然后枚举所有的边$(f,x)$,讨论一下边上的点的子树应该靠谁更近
倍增求出分界点
注意有些没出现在虚树上的子树
注意讨论的时候只讨论链上的不包括端点,否则$f$的子树会被贡献多次
学到的一些$trick:$
1.$pair$的妙用
2.不需要建出虚树只要求虚树的$dfs$序(拓扑序)和$fa$就可以$DP$了
注意$DP$的时候必须先用儿子更新父亲再用父亲更新儿子,因为父亲的最优值有可能在其他儿子
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> using namespace std; #define pii pair<int,int> #define MP make_pair #define fir first #define sec second typedef long long ll; const int N=3e5+5,INF=1e9; inline int read(){ char c=getchar();int x=0,f=1; while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();} while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();} return x*f; } int n,Q; struct Edge{ int v,ne,w; }e[N<<1]; int cnt,h[N]; inline void ins(int u,int v){ cnt++; e[cnt].v=v;e[cnt].ne=h[u];h[u]=cnt; cnt++; e[cnt].v=u;e[cnt].ne=h[v];h[v]=cnt; } int fa[N][20],deep[N],dfn[N],dfc,size[N],All; void dfs(int u){ dfn[u]=++dfc; size[u]=1; for(int i=1;(1<<i)<=deep[u];i++) fa[u][i]=fa[ fa[u][i-1] ][i-1]; for(int i=h[u];i;i=e[i].ne) if(e[i].v!=fa[u][0]){ fa[e[i].v][0]=u; deep[e[i].v]=deep[u]+1; dfs(e[i].v); size[u]+=size[e[i].v]; } } inline int lca(int x,int y){ if(deep[x]<deep[y]) swap(x,y); int bin=deep[x]-deep[y]; for(int i=19;i>=0;i--) if((1<<i)&bin) x=fa[x][i]; for(int i=19;i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return x==y ? x :fa[x][0]; } int a[N],st[N],par[N],dis[N],t[N],m,ans[N]; int remain[N]; inline bool cmp(int x,int y){return dfn[x]<dfn[y];} inline void ins2(int x,int y){par[y]=x;dis[y]=deep[y]-deep[x];} pii g[N]; void dp(int m){ for(int i=m;i>1;i--){ int x=t[i],f=par[x]; g[f]=min(g[f],MP(g[x].fir+dis[x],g[x].sec)); } for(int i=2;i<=m;i++){ int x=t[i],f=par[x]; g[x]=min(g[x],MP(g[f].fir+dis[x],g[f].sec)); } } inline int jump1(int x,int tar){ for(int i=19;i>=0;i--) if(deep[ fa[x][i] ]>=tar) x=fa[x][i]; return x; } inline int jump(int x,int tar){ int bin=deep[x]-tar; for(int i=19;i>=0;i--) if((1<<i)&bin) x=fa[x][i]; return x; } int ora[N]; void solve(){ int n=read(),m=0; for(int i=1;i<=n;i++) ora[i]=a[i]=read(),t[++m]=a[i],g[a[i]]=MP(0,a[i]); sort(a+1,a+1+n,cmp); int top=0; for(int i=1;i<=n;i++){ if(!top) {st[++top]=a[i];continue;} int x=a[i],f=lca(x,st[top]); while(dfn[f]<dfn[st[top]]){ if(dfn[f]>=dfn[st[top-1]]){ ins2(f,st[top--]); if(f!=st[top]) st[++top]=f,t[++m]=f,g[f]=MP(INF,0); break; }else ins2(st[top-1],st[top]),top--; } st[++top]=x; } while(top>1) ins2(st[top-1],st[top]),top--; sort(t+1,t+1+m,cmp); dp(m); for(int i=1;i<=m;i++) remain[t[i]]=size[t[i]]; ans[ g[t[1]].sec ]+=All-size[t[1]]; for(int i=2;i<=m;i++){ int x=t[i],f=par[x];par[x]=0; int t=jump(x,deep[f]+1); remain[f]-=size[t]; if(g[x].sec == g[f].sec) ans[ g[x].sec ]+=size[t]-size[x]; else{ int len=g[x].fir + g[f].fir + dis[x], mid=deep[x]-(len/2-g[x].fir); if( !(len&1) && g[f].sec<g[x].sec ) mid++; int y=jump(x,mid); ans[ g[f].sec ]+=size[t]-size[y]; ans[ g[x].sec ]+=size[y]-size[x]; } } for(int i=1;i<=m;i++) ans[ g[t[i]].sec ]+=remain[t[i]]; for(int i=1;i<=n;i++) printf("%d%c",ans[ora[i]],i==n?'\n':' '),ans[ora[i]]=0; } int main(){ //freopen("in","r",stdin); n=read();All=n; for(int i=1;i<n;i++) ins(read(),read()); dfs(1); Q=read(); while(Q--) solve(); }
Copyright:http://www.cnblogs.com/candy99/