三色树——需要深度思考的树形dp
三色树
给出一个N个节点的无根树,每条边有非负边权,每个节点有三种颜色:黑,白,灰。
一个合法的无根树满足:树中不含有黑色结点或者含有至多一个白色节点。
现在希望你通过割掉几条树边,使得形成的若干树合法,并最小化割去树边权值的和。
第一行一个正整数N,表示树的节点个数。
第二行N个整数Ai,表示i号节点的颜色,0 表示黑色,1表示白色,2表示灰色。
接下来N-1行每行三个整数Xi Yi Zi,表示一条连接Xi和Yi权为Zi的边。
输出一个整数表示其最小代价。
5
0 1 1 1 0
1 2 5
1 3 3
5 2 5
2 4 16
10
样例解释:
花费10的代价删去边(1, 2)和边(2, 5)。
20%的数据满足N≤10。
另外30%的数据满足N≤100,000,且保证树是一条链。
100%的数据满足N≤300,000,0≤Zi≤1,000,000,000,Ai∈{0,1,2}。
分析:
其实明眼人都能看出这是树形dp,可是当我们仔细去思考该维护什么时,我们就陷进去了。因为我们所想的任何方法的维护都十分的复杂,很容易给人一种思路错了的错觉。可是这题确确实实就是这么复杂,很多人想得到方向,却无法深入,接下来分析一下。
我们要维护的是3类情况(用f来表示):
1.f[i][0]表示以i为根节点的子树在切割后不含黑点的最小代价;
2.f[i][1]表示以i为根节点的子树在切割后不含白点的最小代价;
3.f[i][2]表示以i为根节点的子树在切割后含一个白点的最小代价;
按这样分之后答案就是根节点s三个值的最小值
接下来考虑转移方程:
1.当col[i](即该点颜色)=0时,明显不符合无黑点,所以f[i][0]=inf(无穷大);而当col[i]!=0时,这时要考虑断边情况,很容易可以列得
$f[i][0]= \sum_{son}$min(f[son][0],min(f[son][1],f[son][2])+w)<--w为边权
2.当col[i]=1时,同样的不符合无白点,所以f[i][1]=inf;而当col[i]!=1时,这时要考虑断边情况,一样可以列得
$f[i][1]= \sum_{son}$min(f[son][1],min(f[son][0],f[son][2])+w) (2与1几乎一模一样)
3.当col[i]=1时,这时该点已经是一个白点了,所以方程式和2情况的第二种一样。剩下最后一种情况(最复杂的一种),即col[i]!=1时,这时我们直接处理的话会列出一长串,但我们可以和f[i][1]结合起来,我们只要从最后算出的f[i][1]中减去一种min(f[son][1],min(f[son][0],f[son][2])+w)再加上f[son][2]就能维护一个白点的情况,同样的,我们要使f[i][2]最小化,而最终的
f[i][1]又是个定值,所以我们要最大化min(f[son][1],min(f[son][0],f[son][2])+w)+f[son][2].
所以最终方程式为:$f[i][2]= f[i][1](最终的)-max(min(f[son][1],min(f[son][0],f[son][2])+w)-f[son][2]).<--减号是因为加了个括号
代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<cmath> 5 #include<queue> 6 using namespace std; 7 #define debug printf("zjyvegetable\n") 8 #define int long long 9 inline int read(){ 10 int a=0,b=1;char c=getchar(); 11 while(!isdigit(c)){if(c=='-')b=-1;c=getchar();} 12 while(isdigit(c)){a=a*10+c-'0';c=getchar();} 13 return a*b; 14 } 15 const int N=4e5+50,M=2e6+50,inf=123456789012345678; 16 int n,col[N],tot,vis[N],h[N],ver[M],nx[M],ed[M],f[N][3],t[M],top; 17 void add(int u,int v,int z){ 18 ver[++tot]=v;ed[tot]=z; 19 nx[tot]=h[u];h[u]=tot; 20 } 21 inline void dfs(int x){ 22 vis[x]=1; 23 if(col[x]==0)f[x][0]=inf; 24 else if(col[x]==1)f[x][1]=inf; 25 int maxn=0; 26 for(int i=h[x];i;i=nx[i]){ 27 int v=ver[i]; 28 if(vis[v])continue; 29 dfs(v); 30 if(col[x]==0){ 31 f[x][1]+=min(f[v][1],min(f[v][0],f[v][2])+ed[i]); 32 maxn=max(maxn,min(f[v][1],ed[i]+min(f[v][0],f[v][2]))-f[v][2]); 33 } 34 else if(col[x]==1){ 35 f[x][0]+=min(f[v][0],min(f[v][1],f[v][2])+ed[i]); 36 f[x][2]+=min(f[v][1],ed[i]+min(f[v][0],f[v][2])); 37 } 38 else{ 39 f[x][0]+=min(f[v][0],min(f[v][1],f[v][2])+ed[i]); 40 f[x][1]+=min(f[v][1],min(f[v][0],f[v][2])+ed[i]); 41 maxn=max(maxn,min(f[v][1],ed[i]+min(f[v][0],f[v][2]))-f[v][2]); 42 } 43 } 44 if(col[x]!=1)f[x][2]=f[x][1]-maxn; 45 } 46 signed main(){ 47 //freopen("tree2.in","r",stdin); 48 //freopen("tree2.out","w",stdout); 49 int u,v,z; 50 n=read(); 51 for(int i=1;i<=n;i++){ 52 col[i]=read(); 53 } 54 for(int i=1;i<n;i++){ 55 u=read();v=read(); 56 z=read(); 57 add(u,v,z);add(v,u,z); 58 } 59 dfs(1); 60 printf("%lld\n",min(f[1][0],min(f[1][1],f[1][2]))); 61 return 0; 62 }
留在最后的话:
由于这题数据有长链,所以要用人工栈,而笔者由于懒而只打了dfs,拿不到全部分,望理解。
后来笔者在某题被迫学了人工栈,所以回来填坑了,这里补一下人工栈部分代码:
void fake_dfs(int begin){ sta[++top]=begin; while(top){ int x=sta[top]; if(!vis[x]){ if(col[x]==0)f[x][0]=inf; else if(col[x]==1)f[x][1]=inf; for(int i=h[x];i;i=nx[i]){ int v=ver[i]; if(v==fa[x])continue; fa[v]=x; sta[++top]=v; } vis[x]=1; } else{ int maxn=0;top--; for(int i=h[x];i;i=nx[i]){ int v=ver[i]; if(v==fa[x])continue; if(col[x]==0){ f[x][1]+=min(f[v][1],min(f[v][0],f[v][2])+ed[i]); maxn=max(maxn,min(f[v][1],ed[i]+min(f[v][0],f[v][2]))-f[v][2]); } else if(col[x]==1){ f[x][0]+=min(f[v][0],min(f[v][1],f[v][2])+ed[i]); f[x][2]+=min(f[v][1],ed[i]+min(f[v][0],f[v][2])); } else{ f[x][0]+=min(f[v][0],min(f[v][1],f[v][2])+ed[i]); f[x][1]+=min(f[v][1],min(f[v][0],f[v][2])+ed[i]); maxn=max(maxn,min(f[v][1],ed[i]+min(f[v][0],f[v][2]))-f[v][2]); } } if(col[x]!=1)f[x][2]=f[x][1]-maxn; } } }