【Luogu】P4103大工程(虚树DP)
我貌似发现这类DP就是先别管什么虚树……把树形DP搞出来套上虚树板子就好了
这个树形DP就是设sum为答案(子树内被询问点两两距离之和),sumd为子树内所有被询问点的深度和,size为子树内被询问点的个数,maxi指子树内最深的被询问点的深度,mini指子树内最浅的被询问点的深度
然后考虑我们dfs到x,它的儿子已经遍历到一半,新加进来一个儿子to
显然$sum[x] \mathrel{+}= sum[to] + (sumd[x]-deep[x]\cdot size[x])\cdot size[to] + (sumd[to]-deep[x]\cdot size[to])\cdot size[x]$
$sumd[x] \mathrel{+}= sumd[to]$，$size[x] \mathrel{+}= size[to]$
$maxi[x]=max(maxi[x],maxi[to])$
只要注意这类方程的先后顺序即可：跨子树的贡献必须用合并之前的 $size[x]$、$sumd[x]$ 计算，算完之后再把儿子的 $size$、$sumd$ 累加进来。
另外注意输出的先后顺序：依次输出距离和、最小距离、最大距离。
// Luogu P4103 / HEOI2014 "big project".
// For every query set of k tree nodes, print
//   (1) the sum of distances over all unordered pairs of query nodes,
//   (2) the minimum pairwise distance,
//   (3) the maximum pairwise distance.
// Each query builds the virtual tree of its nodes and runs a subtree DP on it.
// All traversals are iterative, so a path-shaped tree with ~10^6 nodes cannot
// overflow the call stack.
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <vector>

namespace {

const int kMaxN = 1020020;
const int kLog = 22;                // 2^21 > kMaxN: enough binary-lifting levels
const long long kInf = 0x7fffffff;  // "nothing seen yet" sentinel for minima

// ---- original tree (adjacency list, each undirected edge stored twice) ----
int head[kMaxN];
int edgeNext[2 * kMaxN];
int edgeTo[2 * kMaxN];
int edgeCount;
int dfn[kMaxN];       // DFS preorder index, 1-based; dfn[0] == 0 acts as a sentinel
int depth[kMaxN];     // the root has depth 1, so every real node has depth >= 1
int up[kMaxN][kLog];  // up[v][j]: 2^j-th ancestor of v, 0 above the root

// ---- per-query virtual-tree state (indexed by original node id) ----
int query[kMaxN];
int stk[kMaxN];       // monotone stack; stk[0] stays 0 (dfn 0) as a sentinel
int vparent[kMaxN];   // parent in the current virtual tree
bool isQuery[kMaxN];
long long cnt[kMaxN];         // query nodes merged into this subtree so far
long long depthSum[kMaxN];    // sum of their depths
long long deepest[kMaxN];     // max depth among them (0 when none)
long long shallowest[kMaxN];  // min depth among them (kInf when none)
long long pairSum[kMaxN];     // sum of pairwise distances among them
long long farthest[kMaxN];    // max pairwise distance (0 while fewer than 2)
long long nearest[kMaxN];     // min pairwise distance (kInf while fewer than 2)

// Reads a (possibly negative) decimal integer from stdin; returns 0 at EOF.
inline long long readInt() {
    long long value = 0;
    long long sign = 1;
    int ch = std::getchar();  // int, not char: keeps EOF distinct and isdigit defined
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return value * sign;
}

void addEdge(int from, int to) {
    edgeTo[++edgeCount] = to;
    edgeNext[edgeCount] = head[from];
    head[from] = edgeCount;
}

// Computes dfn, depth and the binary-lifting table for the tree rooted at 1.
// Uses an explicit stack; popping order is still a valid DFS preorder, so
// every subtree occupies a contiguous dfn range.
void buildTree(int n) {
    std::vector<int> pending;
    pending.reserve(n > 0 ? static_cast<std::size_t>(n) : 1u);
    int timer = 0;
    depth[1] = 1;
    up[1][0] = 0;
    pending.push_back(1);
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        dfn[x] = ++timer;
        for (int e = head[x]; e != 0; e = edgeNext[e]) {
            const int y = edgeTo[e];
            if (y == up[x][0]) continue;  // edge back to the parent
            up[y][0] = x;
            depth[y] = depth[x] + 1;
            pending.push_back(y);
        }
    }
    for (int j = 1; j < kLog; ++j)
        for (int v = 1; v <= n; ++v)
            up[v][j] = up[up[v][j - 1]][j - 1];
}

// Lowest common ancestor via binary lifting.
int lca(int x, int y) {
    if (depth[x] < depth[y]) std::swap(x, y);
    for (int j = 0, diff = depth[x] - depth[y]; diff > 0; ++j, diff >>= 1)
        if (diff & 1) x = up[x][j];
    if (x == y) return x;
    for (int j = kLog - 1; j >= 0; --j) {
        if (up[x][j] != up[y][j]) {
            x = up[x][j];
            y = up[y][j];
        }
    }
    return up[x][0];
}

bool byDfn(int a, int b) { return dfn[a] < dfn[b]; }

// Builds the virtual tree of query[1..k], which must already be sorted by dfn.
// `nodes` receives every virtual-tree node. `order` receives every non-root node
// in the order it is attached to its parent (recorded in vparent[]). A node is
// always attached after all of its own children, because a stack entry is
// popped only after everything above it.
// Returns the root, which is the LCA of all query nodes.
int buildVirtualTree(int k, std::vector<int>& nodes, std::vector<int>& order) {
    nodes.clear();
    order.clear();
    int top = 0;
    stk[0] = 0;
    auto push = [&](int x) {
        stk[++top] = x;
        nodes.push_back(x);
    };
    auto attach = [&](int parent, int child) {
        vparent[child] = parent;
        order.push_back(child);
    };
    for (int i = 1; i <= k; ++i) {
        const int v = query[i];
        if (top > 0) {
            const int l = lca(v, stk[top]);
            while (dfn[l] < dfn[stk[top]]) {
                if (dfn[l] >= dfn[stk[top - 1]]) {
                    // l sits between stk[top-1] and stk[top] on the stacked chain.
                    attach(l, stk[top]);
                    --top;
                    if (stk[top] != l) push(l);
                    break;
                }
                attach(stk[top - 1], stk[top]);
                --top;
            }
        }
        push(v);
    }
    while (top > 1) {
        attach(stk[top - 1], stk[top]);
        --top;
    }
    return stk[1];
}

void resetNode(int x) {
    cnt[x] = depthSum[x] = deepest[x] = pairSum[x] = farthest[x] = 0;
    shallowest[x] = nearest[x] = kInf;
}

// Adds x itself to its accumulated subtree set if x is a query node.
// Must run after all children of x have been merged.
void absorbSelf(int x) {
    if (!isQuery[x]) return;
    const long long d = depth[x];
    if (cnt[x] > 0) {
        pairSum[x] += depthSum[x] - d * cnt[x];  // x paired with each descendant
        farthest[x] = std::max(farthest[x], deepest[x] - d);
        nearest[x] = std::min(nearest[x], shallowest[x] - d);
    }
    depthSum[x] += d;
    deepest[x] = std::max(deepest[x], d);
    shallowest[x] = std::min(shallowest[x], d);
    ++cnt[x];
}

// Merges the finished subtree `child` into `parent`. Cross-pair terms use the
// parent's totals from BEFORE this child is added.
void mergeChild(int parent, int child) {
    const long long d = depth[parent];
    if (cnt[parent] > 0) {
        pairSum[parent] += (depthSum[parent] - d * cnt[parent]) * cnt[child] +
                           (depthSum[child] - d * cnt[child]) * cnt[parent];
        farthest[parent] = std::max(farthest[parent], deepest[parent] + deepest[child] - 2 * d);
        nearest[parent] = std::min(nearest[parent], shallowest[parent] + shallowest[child] - 2 * d);
    }
    pairSum[parent] += pairSum[child];
    farthest[parent] = std::max(farthest[parent], farthest[child]);
    nearest[parent] = std::min(nearest[parent], nearest[child]);
    cnt[parent] += cnt[child];
    depthSum[parent] += depthSum[child];
    deepest[parent] = std::max(deepest[parent], deepest[child]);
    shallowest[parent] = std::min(shallowest[parent], shallowest[child]);
}

}  // namespace

int main() {
    const int n = static_cast<int>(readInt());
    for (int i = 1; i < n; ++i) {
        const int x = static_cast<int>(readInt());
        const int y = static_cast<int>(readInt());
        addEdge(x, y);
        addEdge(y, x);
    }
    buildTree(n);

    std::vector<int> nodes;
    std::vector<int> order;
    long long m = readInt();
    while (m-- > 0) {
        const int k = static_cast<int>(readInt());
        for (int i = 1; i <= k; ++i) {
            query[i] = static_cast<int>(readInt());
            isQuery[query[i]] = true;
        }
        if (k <= 0) {
            // No pairs at all: same output a single-node query produces.
            std::printf("0 %lld 0\n", kInf);
            continue;
        }
        std::sort(query + 1, query + k + 1, byDfn);
        const int root = buildVirtualTree(k, nodes, order);
        for (int x : nodes) resetNode(x);
        // Children are attached before their parents, so one forward pass
        // finishes each subtree before it is merged upward.
        for (int child : order) {
            absorbSelf(child);
            mergeChild(vparent[child], child);
        }
        absorbSelf(root);
        // Output order required by the problem: sum, minimum, maximum.
        std::printf("%lld %lld %lld\n", pairSum[root], nearest[root], farthest[root]);
        for (int i = 1; i <= k; ++i) isQuery[query[i]] = false;
    }
    return 0;
}