树形DP 学习笔记

树形DP学习笔记

ps: 本文内容与蓝书一致

树的重心

  • 概念: 一颗树中的一个节点其最大子树的节点树最小
  • 解法:对与每个节点求他儿子的\(size\) ,上方子树的节点个数为\(n-size_u\) ,求对于每个节点子树的最大值,找出最小的那个就好了;

(我觉得就不需要code了)


树的直径

  • 概念:一颗带权树的最长路径
  • 解法:维护一个节点到叶子节点的最大距离\(d1[i]\)和次大距离\(d2[i]\) ,最大距离就是$max {d1[i]+d2[i] } $

code

#include<iostream>
#include<cstdio>
using namespace std;
const int N=1e4+5;
int n;
struct pp
{
    int to,next;
}w[2*N];
int head[N],cnt;
int d1[N],d2[N];
int ans;
void add(int x,int y)
{
    cnt++;
    w[cnt].next=head[x];
    w[cnt].to=y;
    head[x]=cnt;
}
void dfs(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs(t,x);
            if(d1[t]+1>d1[x])
            {
                d2[x]=d1[x];
                d1[x]=d1[t]+1;
            }
            else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
        }
    }
    return ;
}
void find_ans(int x,int fa)
{
    ans=max(ans,d1[x]+d2[x]);
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa) find_ans(t,x);
    }
    return;
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("diam.in","r",stdin);
    freopen("diam.out","w",stdout);
#endif
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs(1,0);
    find_ans(1,0);
    printf("%d",ans);
    return 0;
}

例题

P4480 逃学的小孩

  • 大概思路:求出树的直径以及其左右端点,再设\(d[i]\)为树上节点\(i\)到左右端点距离更小的那个,然后求出\(max \{d[i]\}\),然后以这个值加上直径就是\(ans\)

code

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int N=2e5+5;
struct pp
{
    int next,to;
    ll qu;
}w[N*2];
int head[N],cnt;
int n,m;
bool v[N];
ll d1[N],d2[N],dl[N],dr[N];
int f1[N],f2[N];
int r,l;
ll ans,mans;
void add(int x,int y,int z)
{
    w[++cnt].next=head[x];
    w[cnt].qu=z;
    w[cnt].to=y;
    head[x]=cnt;
}
int read()
{
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}
void dfs1(int x)
{
    if(v[x]) return ;
    v[x]=1;
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(!v[t])
        {
            dfs1(t);
            if(d1[t]+w[i].qu>d1[x])
            {
                f2[x]=f1[x];
                f1[x]=f1[t];
                d2[x]=d1[x];
                d1[x]=d1[t]+w[i].qu;
            }
            else if(d1[t]+w[i].qu>d2[x]) d2[x]=d1[t]+w[i].qu,f2[x]=f1[t];
        }
        
    }
    return;
}
void find_ans(int x)
{
    if(v[x]) return;
    v[x]=1;
    if(ans<d1[x]+d2[x])
    {
        ans=d1[x]+d2[x];
        l=f1[x];
        r=f2[x];
    }
    for(int i=head[x];i;i=w[i].next) find_ans(w[i].to);
}
void dfs2(int x)
{
    if(v[x]) return;
    v[x]=1;
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(!v[t])
        {
            dl[t]=dl[x]+w[i].qu;
            dfs2(t);
        }
    }
    return;
}
void dfs3(int x)
{
    if(v[x])return;
    v[x]=1;
    
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(!v[t])
        {
            dr[t]=dr[x]+w[i].qu;
            dfs3(t);
        }
    }
    return;
}
void dfs_ans(int x)
{
    if(v[x]) return;
    v[x]=1;
    mans=max(mans,min(dl[x],dr[x]));
    for(int i=head[x];i;i=w[i].next) dfs_ans(w[i].to);
    return;
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("Chris.in","r",stdin);
    freopen("Chris.out","w",stdout);
#endif
    n=read();
    m=read();
    for(int i=1;i<=m;i++)
    {
        int x,y,z;
        x=read(),y=read(),z=read();
        add(x,y,z);
        add(y,x,z);
    }
    for(int i=1;i<=n;i++) f1[i]=i;
    dfs1(1);
    memset(v,0,sizeof(v));
    find_ans(1);
    memset(v,0,sizeof(v));
    dfs2(l);
    memset(v,0,sizeof(v));
    dfs3(r);
    memset(v,0,sizeof(v));
    dfs_ans(1);
    printf("%lld",ans+mans);
    return 0;
}

树的中心

  • 概念:给出一颗带权树,求一个节点,使得此节点到树中其他节点的最远距离最小;

  • 解法:如果是一颗没有负边权的树,那直接找到直径的中点就好;

    但是这里我们考虑有负边权的情况:

    有两种情况:

    1. \(u\)点向上的最长路径,设为\(up[u]\);
    2. \(u\)点向下,即\(u\)到叶节点的最远距离,设为\(d1[u]\)(次远设为\(d2[u]\));

    \(d1[u]\)\(d2[u]\)都会求,问题是\(up[u]\)该怎么求?

    还是分类讨论,设\(u\)的父亲为\(x\),\(d1[x]\)来自于子节点\(v\);那对于\(u\):

    1. 如果\(u!=v\),那么\(up[u]=max\{d1[x],up[x]\}+dis[x][t]\);
    2. 如果\(u==v\),那么\(up[u]=max\{d2[x],up[x]\}+dis[x][t]\),这也是为什么要维护\(d2[x]\)的原因;

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
    int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int root,far;
int read()
{
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}
void add(int x,int y)
{
    cnt++;
    w[cnt].next=head[x];
    w[cnt].to=y;
    head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs1(t,x);
            if(d1[t]+1>d1[x])
            {
                pre[x]=t;
                d2[x]=d1[x];
                d1[x]=d1[t]+1;
            }
            else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
        }
    }
    return;
}
void dfs2(int x,int fa)
{
    int minx=min(u[x],d1[x]);
    if(far<minx)
    {
        root=x;
        far=minx;
    }
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if (t!=fa)
        {
            if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
            else u[t]=max(d2[x],u[x])+1;
            dfs2(t,x);
        }
    }
    return ;
}
int main()
{
    n=read(),k=read();
    for(int i=1;i<n;i++)
    {
        int x,y;
        x=read(),y=read();
        add(x,y);
        add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    printf("%d",root);
    return 0;
}

例题

P5536核心城市

  • 思路:显然其中一定会有一个城市为这颗树的中心;那找出这个中心,把这颗无根树变为以它为根的有根树;再求出除根节点以外的每个节点所能到达的最大深度\(deepfar[i]\),这就是这个节点最远所能到达的距离;然后\(sort\)一下\(deepfar[]\),答案就是\(deepfar[k+1]+1\);

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
    int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int fardeep[N];
int root,far;
int read()
{
    int f=1;
    char ch;
    while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
    int res=ch-'0';
    while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
    return res*f;
}
void add(int x,int y)
{
    cnt++;
    w[cnt].next=head[x];
    w[cnt].to=y;
    head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs1(t,x);
            if(d1[t]+1>d1[x])
            {
                pre[x]=t;
                d2[x]=d1[x];
                d1[x]=d1[t]+1;
            }
            else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
        }
    }
    return;
}
void dfs2(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if (t!=fa)
        {
            if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
            else u[t]=max(d2[x],u[x])+1;
            dfs2(t,x);
        }
    }
    return ;
}
void dfs3(int x,int fa)
{
    int minx=min(u[x],d1[x]);
    if(far<minx)
    {
        root=x;
        far=minx;
    }
    for(int i=head[x];i;i=w[i].next) if(w[i].to!=fa) dfs3(w[i].to,x);
    return;
}
void dfs4(int x,int fa)
{
    for(int i=head[x];i;i=w[i].next)
    {
        int t=w[i].to;
        if(t!=fa)
        {
            dfs4(w[i].to,x);
            fardeep[x]=max(fardeep[x],fardeep[t]+1);
        }
    }
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("XR-3.in","r",stdin);
    freopen("XR-3.out","w",stdout);
#endif
    n=read(),k=read();
    for(int i=1;i<n;i++)
    {
        int x,y;
        x=read(),y=read();
        add(x,y);
        add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    dfs3(1,0);
    dfs4(root,0);
    sort(fardeep+1,fardeep+1+n,cmp);
    printf("%d",fardeep[k+1]+1);
    return 0;
}

上面都是有关树的一些经典题型,下面才是今天的主角——树型DP


背包类树型DP

(我觉得把,其实左右子树类树型DP可以归为这一类)

例题

选课

书上的是时间复杂度为\(n^3\)的算法,这里介绍一个优化,可以讲其降为\(n^2\);

  • 泛化物品优化:具体是什么,请参考2009年国家集训队论文——徐持衡《浅谈几类背包问题》,其中有详细解释;

  • 而我对泛化物品优化的感性理解就是:"预留空间"——为在 \(u\) 到到根节点的路径上(包括u)的点预留空间。

    这样就可以在对 \(u\)DP的时候保证他所依赖的物品预先算进去了

    \(dp[u][j]\)的意思就是在预留\(u\)及其到根节点的路径上的点的空间后,还剩下\(j\)的空间的最大价值;

  • 没有优化前,DP方程为:

  • 没有优化前,DP方程为:

\[dp[u][j]=max\{dp[u][j],dp[u][j-k]+dp[v][k]\} \]

这样对于每个节点都要\(n^2\)暴力枚举\(j\)\(k\);

经过优化,我们的DP方程就变为了:

\[\begin{cases} dp[v][j]=dp[u][j](dfs前)\\ dp[u][j]=max\{dp[u][j],dp[v][j-w[v]]+val[v]\}(回溯时) \end{cases} \]

这也是再泛化物品优化下,树型背包的基本DP方程;这样我们只需要\(O(n)\)枚举\(j\)就好了;


ps: 以下代码参考价值不大,建议参考[HAOI2010]软件安装

code

#include<iostream>
#include<algorithm>
#include<queue>
#include<cstdio>
#include<cstring>
using namespace std;

int n,m;
struct edge
{
    int next,to;
}e[1000];
int rt,head[1000],tot,val[1000],dp[1000][1000];
void add(int x,int y)
{
    e[++tot].next=head[x];
    head[x]=tot;
    e[tot].to=y;
}
void dfs(int u,int t)
{
    if (t<=0) return ;
    for (int i=head[u]; i; i=e[i].next)
    {
        int v = e[i].to;
        for (int j=0; j<=t-1; ++j) //为v预留空间
            dp[v][j] = dp[u][j];
        dfs(v,t-1);//对于v的现有空间
        for (int j=1; j<=t; ++j) 
            dp[u][j] = max(dp[u][j],dp[v][j-1]+val[v]);//背包
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        int a;
        scanf("%d%d",&a,&val[i]);
        if(a)
          add(a,i);
        if(!a)add(0,i);
    }
    dfs(0,m);
    printf("%d",dp[0][m]);
}

选择类树型DP

基本DP方程:

\[v\in{son(u)} \begin{cases} f[u][0]=\sum f[v][1] \\ f[u][1]=min\{f[v][1],f[v][0]\}+1 \end{cases} \]

例题

P2016战略游戏

直接套DP方程就好了;

code

#include<iostream>
#include<cstdio>
using namespace std;
int n;
int dp[1605][2];
struct pp
{
	int next,to;
}w[1600<<1];
int head[1600],cnt;
void add(int x,int y)
{
	cnt++;
	w[cnt].to=y;
	w[cnt].next=head[x];
	head[x]=cnt;
}
void dfs(int x,int fa)
{
	dp[x][1]=1;
	for(int i=head[x];i;i=w[i].next)
	{
		int t=w[i].to;
		if(t==fa) continue;
		dfs(t,x);
		dp[x][0]+=dp[t][1];
		dp[x][1]+=min(dp[t][0],dp[t][1]);
	}
	return;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
	{
		int a,k;
		scanf("%d%d",&a,&k);
		for(int i=1;i<=k;i++)
		{
			int b;
			scanf("%d",&b);
			add(a,b);
			add(b,a);
		}
	}
	dfs(0,0);
	printf("%d",min(dp[0][1],dp[0][0]));
	return 0;
}

普通树型DP

这种树型DP更加灵活,就不像前两种有基本固定的DP方程,所以还是直接来几道例题;(滑稽

例题

LOJ #10157. 皇宫看守

乍一看题,啊哈,模板选择树型DP,开开心心打个代码,恭喜你0分;

仔细一看这道题其实不是什么没有上司的舞会,而是一道覆盖DP题,区别在哪呢?

这道题一条边两端至少要有一个点,可以有两个,而没有上司我舞会是一条边两端至多有一个点,可以没有;

那好,这样的话一个节点u的最少经费就不能像选择DP一样单纯的由儿子选不选的而转移过来,因为他们本来互不冲突,而是必须被覆盖到(这里每个节点的覆盖半径为1),这样对于一个节点u的最少经费就可以由覆盖它的节点转移过来,这样的话就需要考虑三种情况:

首先设\(dp[u][0]\)表示被节点\(u\)被父亲覆盖且\(u\)不选,\(dp[u][1]\)表示被自己的子节点覆盖且\(u\)不选,\(dp[u][2]\)表示被自己覆盖;

所以有状态转移方程:

  • 对于\(dp[u][0]\),因为\(u\)不选,所以对于\(u\)的子节点\(v\),要么被\(son(v)\)所覆盖,要么被\(v\)自己覆盖:

\[dp[u][0]=\sum min\{dp[v][1],dp[v][2]\} +a[f[u]]; \]

  • 对于\(dp[u][1]\),要保证\(u\)必须被一个子节点所覆盖到,还要保证\(u\)的子节点\(v\)在不被父亲覆盖的前提下被覆盖到,那显然\(dp[u][1]\),是由\(dp[v][1]\)\(dp[v][2]\)转移过来的,但是如何保证\(dp[u][1]\)的转移中一定包含\(dp[v][2]\)呢?

    这时候有个巧妙的办法,设个参数:

    \[d=min\{d,dp[v][2]-min\{dp[v][1],dp[v][2]\}\} \]

    \(d\)的初始值为\(0x7fffffff\);

    这样对于\(dp[u][1]\)就有状态转移方程:

    \[dp[u][1]=\sum min\{dp[v][1],dp[v][2]\}+d \]

  • 对于\(dp[u][2]\),那很显然它可以由子节点任意三种状态转移过来,但是对于\(dp[v][0]\),它已经加过一遍\(a[u]\),而对于\(dp[u][2]\),只能且必须加一遍\(a[u]\),那怎么办呢?单独特判由\(dp[v][0]\)转移过来的情况,控制\(a[u]\)只加一遍?显然是可以的,但是太麻烦了,那么另外考虑,这里可以看到\(dp[v][0]\)只会往\(dp[u][2]\)上转移,那么可以根据\(dp[u][2]\)需求对\(dp[v][0]\)状态转移方程改一改:

    \[dp[u][0]=\sum min\{dp[v][1],dp[v][2]\} \]

    (这里的\(u\)是对于\(v\)来说的)

    感性理解一下就是如果\(dp[u][2]\)不由\(dp[v][0]\)转移过来那要\(dp[v][0]\)也没有什么用,那由\(dp[v][0]\)转移过来,那在\(dp[u][2]\)这加一遍\(a[u]\)就够了,因为\(dp[u][2]\)已经保证了\(u\)被选,所以不需要\(dp[v][0]\)再保证一遍;

    这样对于\(dp[u][2]\),就有状态转移方程:

    \[dp[u][2]=\sum min\{dp[v][1],dp[v][2],dp[v][0]\} +a[u] \]

总结下来就有三个状态转移方程:

\[\begin{cases} dp[u][0]=\sum min\{dp[v][1],dp[v][2]\};\\ dp[u][1]=\sum min\{dp[v][1],dp[v][2]\}+d ;(d=min\{d,dp[v][2]-min\{dp[v][1],dp[v][2]\}\})\\ dp[u][2]=\sum min\{dp[v][1],dp[v][2],dp[v][0]\} +a[u] \end{cases} \]

(所以,显然书上的状态转移方程是错的)

不难发现,修改后的\(dp[v][0]\)一定小于等于\(dp[v][1]\);所以写代码的时候我顺手把\(dp[u][2]\)的转移方程改成了:

\[dp[u][2]=\sum min\{dp[v][2],dp[v][0]\} +a[u] \]

虽然题目早已经解决了,但我还是想再深究一下;这个方程啥意思?

以我的感性理解就是\(v\)既然已经一定会被它爹\(u\)覆盖到,那就可以不需要保证\(v\)一定被它的儿子所覆盖,修改后的\(dp[v][0]\)刚好就是这种情况;

(好了,bb了这么多废话,就一点有用的东西,直接上代码吧)

code

#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1500 + 5;
int dp[N][3];
int v[N], n, root;
struct pp {
    int next, to;
} w[N];
int head[N], cnt, du[N];
void add(int x, int y) {
    cnt++;
    w[cnt].next = head[x];
    w[cnt].to = y;
    head[x] = cnt;
}
void dfs(int x) {
    int d = 0x7fffffff;
    for (int i = head[x]; i; i = w[i].next) {
        int t = w[i].to;
        dfs(t);
        dp[x][0] += min(dp[t][1], dp[t][2]);
        dp[x][1] += min(dp[t][1], dp[t][2]);
        d = min(d, dp[t][2] - min(dp[t][1], dp[t][2]));
        dp[x][2] += min(dp[t][2], dp[t][0]);
    }
    dp[x][1] += d;
    dp[x][2] += v[x];
}
int main() {
#ifndef ONLINE_JUDGE
    freopen("guard.in", "r", stdin);
    freopen("guard.out", "w", stdout);
#endif
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        int x, m;
        scanf("%d", &x);
        scanf("%d", &v[x]);
        scanf("%d", &m);
        for (int j = 1; j <= m; j++) {
            int y;
            scanf("%d", &y);
            add(x, y);
            du[y]++;
        }
    }
    for (int i = 1; i <= n; i++)
        if (!du[i])
            root = i;
    dfs(root);
    printf("%d", min(dp[root][1], dp[root][2]));
    return 0;
}

好了,差不多就结束了,虽然写这个一点耗时,但对于我这个蒟蒻来说加深了对于DP的理解,收获也不小,也不算浪费时间了吧(逃);


PS: 2020.10.9 添加了我对泛化物品优化的理解

posted @ 2020-01-18 17:42  zfz04  阅读(387)  评论(4编辑  收藏  举报