BZOJ 3935: Rbtree 树形DP
Description
给定一颗 N 个点的树,树上的每个点或者是红色,或者是黑色。
每个单位时间内,你可以任选两个点,交换它们的颜色。
出于某种恶趣味,你希望用最少的时间调整结点的颜色,使得对于每个点,离它最近的黑色点与它的距离不超过 x。
Input
输入的第一行包含整数 N 和 x(1 <= x <= 10^9)。
接下来一行 N 个整数 C1-Cn,表示结点的初始颜色。1 表示黑色,0 表示红色。
接下来 N-1 行,每行 3 个整数 ui, vi,wi(1 <= wi <= 10^9),表示点 ui 和 vi 之间存在权值为 wi的边。
Output
输出一个数表示答案;如果无解,输出 “-1”。
我们令 $f[i][j][k]$ 表示 $i$ 的子树中有 $j$ 个黑点,且钦定 $i$ 与 $k$ 配对且 $i$ 子树其余点都配对好的最优解.
特别地,我们还要钦定 $k$ 这个点先不进行任何操作.
这道题难就难在这是一个匹配问题,有好多个点都需要去匹配,但是我们通过状态将问题缩小为只匹配一个点.
我们可以这么设置状态的前提是因为我们发现以 $i$ 为根的子树中只有 $i$ 这个点有可能和子树外的点配对.
因为你画一下图的话,你会发现如果 $i$ 的某个子孙要和 $i$ 为根子树外的点配对的话那个点一定和与 $i$ 配对的相等,否则不会更优.
而其余情况,就是在 $i$ 的子树内部自行解决.
我们还钦定不对 $k$ 进行操作的原因是这样就是两个状态的 $k$ 相同,也可以直接相加.
我们需要先观察/挖掘题目中隐含的性质来简化问题,然后再去根据套路设置状态.
#include <cstdio> #include <cstring> #include <algorithm> #define N 508 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int n,K,edges,dfn,m; int v[N],hd[N],to[N<<1],nex[N<<1],val[N<<1],Q[N]; int f[N][N][N],g[N][N],st[N],ed[N],dis[N][N],size[N]; void add(int u,int v,int c) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v,val[edges]=c; } void getdis(int u,int ff,int y) { if(y==1) st[u]=++dfn,Q[dfn]=u; for(int i=hd[u];i;i=nex[i]) if(to[i]!=ff) dis[y][to[i]]=dis[y][u]+val[i],getdis(to[i],u,y); if(y==1) ed[u]=dfn; } void dfs(int x,int ff) { int i,j,k,mn,a; for(a=1;a<=n;++a) if(a!=x&&dis[a][x]<=K) f[x][0][a]=0; f[x][1][x]=0; size[x]=1; for(i=hd[x];i;i=nex[i]) if(to[i]!=ff) { int y=to[i]; dfs(y,x); memset(g,0x3f,sizeof(g)); for(j=0;j<=size[x]&&j<=m;++j) for(k=0;k<=size[y]&&k+j<=m;++k) { mn=0x3f3f3f3f; for(a=st[y];a<=ed[y];++a) mn=min(mn,f[y][k][Q[a]]+!v[Q[a]]); for(a=1;a<=n;++a) g[j+k][a]=min(g[j+k][a],f[x][j][a]+f[y][k][a]); for(a=1;a<st[y];++a) g[j+k][Q[a]]=min(g[j+k][Q[a]],f[x][j][Q[a]]+mn); for(a=ed[y]+1;a<=n;++a) g[j+k][Q[a]]=min(g[j+k][Q[a]],f[x][j][Q[a]]+mn); } size[x]+=size[y]; for(j=0;j<=size[x]&&j<=m;++j) for(a=1;a<=n;++a) f[x][j][a]=g[j][a]; } } int main() { // setIO("input"); int i,j; scanf("%d%d",&n,&K); for(i=1;i<=n;++i) scanf("%d",&v[i]),m+=v[i]; for(i=1;i<n;++i) { int x,y,z; scanf("%d%d%d",&x,&y,&z),add(x,y,z),add(y,x,z); } memset(f,0x3f,sizeof(f)); for(i=1;i<=n;++i) getdis(i,0,i); dfs(1,0); int ans=0x3f3f3f3f; for(i=1;i<=n;++i) ans=min(ans,f[1][m][i]+!v[i]); printf("%d\n",ans>n?-1:ans); return 0; }