虚树
对于每次询问,虚树由该询问给出的特殊点以及这些点两两之间的LCA组成(实现中通常还会加入根节点,下面两份代码都加入了根节点1)
原树上两个虚树节点之间的链被压缩成虚树上的一条边,边上记录该链的信息(如链上的最小边权或链长)
将特殊点按DFS序排序后,任意两个特殊点的LCA必是某一对DFS序相邻的特殊点的LCA,因此只需求出所有相邻点对的LCA
BZOJ2286(错误写法)
// BZOJ 2286: virtual-tree DP, built with the "two sorts" construction.
// NOTE(review): the notes above label this listing "错误写法" (wrong approach). The
// construction itself -- sort key nodes by dfn, add the LCA of every dfn-adjacent pair
// plus the root, re-sort and dedupe, then attach each node to the nearest ancestor kept
// on a stack -- is a commonly used valid construction; the reason for the label is not
// visible here. Confirm before relying on it.
//
// Input: n, then n-1 lines "u v w", then q, then q queries "k v1 ... vk".
// Output: one line per query with f[1] (see dp()).
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <utility>

typedef long long LL;

const int MAXN = 250005;  // node ids are 1..n
const int MAXE = 500005;  // edge slots: 2*(n-1) for the input tree
const int LOG = 20;       // binary-lifting levels 0..LOG

// Forward-star adjacency lists. They first hold the input tree and are then rebuilt for
// each query's virtual tree (only nd[] of that query's nodes is reset).
// The link array is called `nxt` rather than `next`: a global `next` clashes with
// std::next (C++11 and later) whenever `using namespace std` is in effect.
int nd[MAXN], nxt[MAXE], des[MAXE], ecnt;
LL len[MAXE];

int n, dfncnt;
int dfn[MAXN];             // DFS preorder index
int dep[MAXN];             // depth, root = 1; 0 marks "unvisited" during dfs()
int fa[MAXN][LOG + 1];     // fa[v][i]: 2^i-th ancestor of v, 0 past the root
LL mini[MAXN][LOG + 1];    // mini[v][i]: lightest edge among the 2^i edges above v
int b[MAXN];               // 1 iff the node is a key node of the current query
int a[2 * MAXN];           // key nodes, then adjacent-pair LCAs and the root
int sta[MAXN];             // ancestor chain used while linking the virtual tree
int cur[MAXN], stk[MAXN];  // scratch for the iterative traversals
LL f[MAXN];                // DP value, see dp()

// Orders nodes by DFS preorder.
bool mycomp(int x, int y) { return dfn[x] < dfn[y]; }

// Appends directed edge x -> y with weight le.
void addedge(int x, int y, LL le) {
    nxt[++ecnt] = nd[x];
    des[ecnt] = y;
    len[ecnt] = le;
    nd[x] = ecnt;
}

// Iterative preorder DFS of the input tree from `root` (dep[root] must already be
// non-zero). Edges are scanned in the same order as a recursive DFS would scan them, so
// dfn[] is unchanged, but a path-shaped tree can no longer overflow the call stack.
// Fills dfn, dep, fa[][0] and mini[][0].
void dfs(int root) {
    int top = 0;
    dfn[root] = ++dfncnt;
    cur[root] = nd[root];
    stk[++top] = root;
    while (top) {
        int po = stk[top];
        int p = cur[po];
        if (p == -1) {  // all edges of po scanned
            --top;
            continue;
        }
        cur[po] = nxt[p];
        int v = des[p];
        if (dep[v] == 0) {
            dep[v] = dep[po] + 1;
            fa[v][0] = po;
            mini[v][0] = len[p];
            dfn[v] = ++dfncnt;
            cur[v] = nd[v];
            stk[++top] = v;
        }
    }
}

// Builds the binary-lifting tables from fa[][0] and mini[][0].
void LCA_init() {
    for (int i = 1; i <= LOG; ++i)
        for (int j = 1; j <= n; ++j) {
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
            mini[j][i] = std::min(mini[j][i - 1], mini[fa[j][i - 1]][i - 1]);
        }
}

// Lifts x and y to their lowest common ancestor and returns it. `pmin` receives the
// lightest edge weight on the x-y tree path (1e18 when x == y, i.e. no edges).
int climb(int x, int y, LL &pmin) {
    pmin = (LL)1e18;
    if (dep[x] < dep[y]) std::swap(x, y);
    for (int i = LOG; i >= 0; --i)
        if (dep[fa[x][i]] >= dep[y]) {
            pmin = std::min(pmin, mini[x][i]);
            x = fa[x][i];
        }
    for (int i = LOG; i >= 0; --i)
        if (fa[x][i] != fa[y][i]) {
            pmin = std::min(pmin, std::min(mini[x][i], mini[y][i]));
            x = fa[x][i];
            y = fa[y][i];
        }
    if (x != y) {
        pmin = std::min(pmin, std::min(mini[x][0], mini[y][0]));
        x = fa[x][0];
    }
    return x;
}

// Lowest common ancestor of x and y.
int getlca(int x, int y) {
    LL unused;
    return climb(x, y, unused);
}

// Lightest edge weight on the x-y path, returned as LL (the original routed it through
// an `int` return value, truncating large weights).
LL pathmin(int x, int y) {
    LL m;
    climb(x, y, m);
    return m;
}

// Virtual-tree DP rooted at `root`:
//   f[v] = sum over children c of (b[c] ? w : min(w, f[c])),
// where w is the weight on virtual edge v -> c (the lightest original edge on that
// compressed path). So f[v] is the cheapest total weight of original edges whose
// removal separates v from every key node below it. Implemented as an explicit-stack
// post-order that visits children in the same order as the recursive form, so deep
// (chain-shaped) virtual trees cannot overflow the call stack.
void dp(int root) {
    int top = 0;
    f[root] = 0;
    cur[root] = nd[root];
    stk[++top] = root;
    while (top) {
        int po = stk[top];
        int p = cur[po];
        if (p != -1) {  // descend into the next child
            int c = des[p];
            f[c] = 0;
            cur[c] = nd[c];
            stk[++top] = c;
            continue;
        }
        // po is finished: fold it into its parent through the parent's current edge.
        --top;
        if (top) {
            int par = stk[top];
            int e = cur[par];
            f[par] += b[po] ? len[e] : std::min(len[e], f[po]);
            cur[par] = nxt[e];
        }
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) nd[i] = -1;
    for (int i = 1; i < n; ++i) {
        int t1, t2;
        LL t3;
        scanf("%d%d%lld", &t1, &t2, &t3);
        addedge(t1, t2, t3);
        addedge(t2, t1, t3);
    }
    dep[1] = 1;
    dfs(1);
    LCA_init();

    int q;
    scanf("%d", &q);
    while (q--) {
        int k;
        scanf("%d", &k);
        for (int i = 1; i <= k; ++i) scanf("%d", &a[i]);
        std::sort(a + 1, a + k + 1, mycomp);

        // Extra virtual nodes: LCA of each dfn-adjacent key pair, plus the root.
        for (int i = 1; i < k; ++i) a[i + k] = getlca(a[i], a[i + 1]);
        a[2 * k] = 1;
        // Clear the extra nodes first so a key node that is also an LCA stays marked.
        for (int i = k + 1; i <= 2 * k; ++i) b[a[i]] = 0;
        for (int i = 1; i <= k; ++i) b[a[i]] = 1;

        std::sort(a + 1, a + 2 * k + 1, mycomp);
        a[0] = 0;  // sentinel: node ids are >= 1, so a[1] is never skipped as a duplicate
        for (int i = 1; i <= 2 * k; ++i) nd[a[i]] = -1;

        // Link every node, in dfn order, to the deepest stacked node that is its ancestor.
        int top = 0;
        ecnt = 0;
        for (int i = 1; i <= 2 * k; ++i) {
            if (a[i] == a[i - 1]) continue;  // duplicate after sorting
            while (top && getlca(sta[top], a[i]) != sta[top]) --top;
            if (top) addedge(sta[top], a[i], pathmin(sta[top], a[i]));
            sta[++top] = a[i];
        }
        dp(1);
        printf("%lld\n", f[1]);
    }
    return 0;
}
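正确写法为按DFN递增的顺序依次插入特殊点,用栈保存从根出发的一条链(栈中每个节点都是其上方节点的祖先);对上一个插入的节点(栈顶)与当前节点的LCA是否就是栈顶分类讨论
是则当前节点直接入栈;否则不断退栈,每个退栈节点与栈中紧挨在它下面的节点连边(最后一个退栈的节点改为连向它与当前节点的LCA,若该LCA不在栈中则将其入栈),最后将当前节点入栈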
如下(BZOJ3611)
// BZOJ 3611: virtual-tree DP. For each query's key nodes, prints
//   <sum of distances over all key-node pairs> <min pairwise distance> <max pairwise distance>
// where the distance is the number of edges of the input tree (see dis()).
//
// Input: n, then n-1 lines "u v", then q, then q queries "k v1 ... vk".
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <utility>

typedef long long LL;

const int MAXN = 1000005;  // node ids are 1..n; also bounds a[] (up to n + 1 entries)
const int MAXE = 2000005;  // edge slots: 2*(n-1) for the input tree
const int LOG = 20;        // binary-lifting levels 0..LOG
const int INF = 1000000000;

// Forward-star adjacency lists. They first hold the input tree and are then rebuilt for
// each query's virtual tree (only nd[] of that query's nodes is reset).
// The link array is called `nxt` rather than `next`: a global `next` clashes with
// std::next (C++11 and later) whenever `using namespace std` is in effect.
int nd[MAXN], nxt[MAXE], des[MAXE], len[MAXE], ecnt;

int n, k, dfncnt;
int dfn[MAXN];             // DFS preorder index
int dep[MAXN];             // depth, root = 1; 0 marks "unvisited" during dfs()
int fa[MAXN][LOG + 1];     // fa[v][i]: 2^i-th ancestor of v, 0 past the root
int b[MAXN];               // 1 iff the node is a key node of the current query
int maxi[MAXN], mini[MAXN];// farthest / nearest key node below v (+-INF when none yet)
LL f[MAXN];                // number of key nodes in v's processed virtual subtree
LL g[MAXN];                // sum of distances from v to those key nodes
// a[] holds the key nodes, then the LCAs created while linking, then the root; that is
// at most n distinct nodes plus the appended root, i.e. n + 1 entries. The original
// size 1000001 overflowed by one when n = 10^6 and every node was a key.
int a[MAXN];
int sta[MAXN];             // ancestor chain used while linking the virtual tree
int cur[MAXN], stk[MAXN];  // scratch for the iterative traversals
LL ans;
int ansmin, ansmax;

// Orders nodes by DFS preorder.
bool mycomp(int x, int y) { return dfn[x] < dfn[y]; }

// Adds undirected, unweighted input-tree edge x - y.
void addedge1(int x, int y) {
    nxt[++ecnt] = nd[x];
    des[ecnt] = y;
    nd[x] = ecnt;
    nxt[++ecnt] = nd[y];
    des[ecnt] = x;
    nd[y] = ecnt;
}

// Appends directed virtual-tree edge x -> y of length le.
void addedge(int x, int y, int le) {
    nxt[++ecnt] = nd[x];
    des[ecnt] = y;
    len[ecnt] = le;
    nd[x] = ecnt;
}

// Iterative preorder DFS of the input tree from `root` (dep[root] must already be
// non-zero). Edges are scanned in the same order as a recursive DFS would scan them, so
// dfn[] is unchanged, but a path of up to 10^6 nodes can no longer overflow the call
// stack. Fills dfn, dep and fa[][0].
void dfs(int root) {
    int top = 0;
    dfn[root] = ++dfncnt;
    cur[root] = nd[root];
    stk[++top] = root;
    while (top) {
        int po = stk[top];
        int p = cur[po];
        if (p == -1) {  // all edges of po scanned
            --top;
            continue;
        }
        cur[po] = nxt[p];
        int v = des[p];
        if (!dep[v]) {
            dep[v] = dep[po] + 1;
            fa[v][0] = po;
            dfn[v] = ++dfncnt;
            cur[v] = nd[v];
            stk[++top] = v;
        }
    }
}

// Builds the binary-lifting ancestor table from fa[][0].
void LCA_ini() {
    for (int i = 1; i <= LOG; ++i)
        for (int j = 1; j <= n; ++j) fa[j][i] = fa[fa[j][i - 1]][i - 1];
}

// Lowest common ancestor of x and y.
int getlca(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    for (int i = LOG; i >= 0; --i)
        if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
    if (x == y) return x;
    for (int i = LOG; i >= 0; --i)
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    return fa[x][0];
}

// Number of input-tree edges between x and y.
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[getlca(x, y)]; }

// DP pre-order step: a key node counts itself and is at distance 0 from itself.
void dp_enter(int po) {
    if (b[po]) {
        f[po]++;
        maxi[po] = 0;
        mini[po] = 0;
    }
}

// Folds the finished child des[p] into po through virtual edge p. Every key node seen
// so far under po is paired with every key node under the child, updating the pairwise
// max/min and adding their distance sum to `ans`.
void dp_merge(int po, int p) {
    int c = des[p];
    ansmax = std::max(ansmax, maxi[po] + maxi[c] + len[p]);
    maxi[po] = std::max(maxi[c] + len[p], maxi[po]);
    ansmin = std::min(ansmin, mini[po] + mini[c] + len[p]);
    mini[po] = std::min(mini[c] + len[p], mini[po]);
    ans += (g[po] + f[po] * len[p]) * f[c] + g[c] * f[po];
    f[po] += f[c];
    g[po] += g[c] + (LL)len[p] * f[c];
}

// DP post-order step. For a key node with no key below it, this lifts ansmax to 0.
// The ansmin branch looks unreachable (dp_enter zeroes mini[] of key nodes and merges
// only lower it) but is kept to match the original output exactly.
void dp_leave(int po) {
    if (b[po]) ansmax = std::max(ansmax, maxi[po]);
    if (b[po] && mini[po] > 0) ansmin = std::min(ansmin, mini[po]);
}

// Runs the DP over the virtual tree rooted at `root` as an explicit-stack post-order.
// The enter/merge/leave sequence is the same as in a recursive DFS over nd[] lists, so
// the results are identical, while deep virtual trees cannot overflow the call stack.
void dp(int root) {
    int top = 0;
    dp_enter(root);
    cur[root] = nd[root];
    stk[++top] = root;
    while (top) {
        int po = stk[top];
        if (cur[po] != -1) {  // descend into the next child
            int c = des[cur[po]];
            dp_enter(c);
            cur[c] = nd[c];
            stk[++top] = c;
            continue;
        }
        dp_leave(po);
        if (--top) {  // fold po into its parent through the parent's current edge
            int par = stk[top];
            dp_merge(par, cur[par]);
            cur[par] = nxt[cur[par]];
        }
    }
}

// Reads one query, builds its virtual tree (rooted at 1) and prints the three answers.
void solve() {
    scanf("%d", &k);
    ecnt = 0;
    int nowk = k;
    for (int i = 1; i <= k; ++i) {
        scanf("%d", &a[i]);
        nd[a[i]] = -1;
    }
    nd[1] = -1;
    std::sort(a + 1, a + k + 1, mycomp);

    // Insert key nodes in dfn order; sta[1..top] is a root-downward ancestor chain.
    int top = 1;
    sta[1] = 1;
    for (int i = 1; i <= k; ++i) {
        // Stack top is an ancestor of (or equal to) a[i]: just extend the chain.
        if (getlca(a[i], sta[top]) == sta[top]) {
            if (a[i] != sta[top]) sta[++top] = a[i];
            continue;
        }
        // Pop nodes that are not ancestors of a[i]; link the popped part of the chain.
        int otop = top;
        while (getlca(a[i], sta[top]) != sta[top]) --top;
        for (int j = top + 1; j < otop; ++j)
            addedge(sta[j], sta[j + 1], dis(sta[j], sta[j + 1]));
        // Attach the lowest popped node to the new stack top, or to a fresh LCA node.
        int newin = getlca(a[i], sta[top + 1]);
        if (newin == sta[top]) {
            addedge(sta[top], sta[top + 1], dis(sta[top + 1], sta[top]));
            sta[++top] = a[i];
        } else {
            nd[newin] = -1;
            addedge(newin, sta[top + 1], dis(sta[top + 1], newin));
            a[++nowk] = newin;
            sta[++top] = newin;
            sta[++top] = a[i];
        }
    }
    // Link whatever chain remains on the stack.
    for (int i = 2; i <= top; ++i) addedge(sta[i - 1], sta[i], dis(sta[i], sta[i - 1]));
    a[++nowk] = 1;

    for (int i = 1; i <= nowk; ++i) {
        maxi[a[i]] = -INF;
        mini[a[i]] = INF;
        f[a[i]] = g[a[i]] = 0;
    }
    // Clear non-key nodes first so a key node (e.g. the root) stays marked.
    for (int i = k + 1; i <= nowk; ++i) b[a[i]] = 0;
    for (int i = 1; i <= k; ++i) b[a[i]] = 1;

    ans = 0;
    ansmin = INF;
    ansmax = -INF;
    dp(1);
    printf("%lld %d %d\n", ans, ansmin, ansmax);
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) nd[i] = -1;
    for (int i = 1; i < n; ++i) {
        int t1, t2;
        scanf("%d%d", &t1, &t2);
        addedge1(t1, t2);
    }
    dep[1] = 1;
    dfs(1);
    LCA_ini();
    int q;
    scanf("%d", &q);
    while (q--) solve();
    return 0;
}