树形dp例题 + 学习笔记(入门版)

树形dp,即在树上进行dp。

需要对树这一数据结构有清晰的了解,还需要学会树的遍历。

难点常常在于状态方程的书写。

例题

例题都来自https://www.luogu.com.cn/training/214#problems

一、没有上司的舞会

题意
树上每个结点有权值,要求在树上选一些点,满足有父子关系的结点只能出现一个,问选出的最大的权值和。

思路
dp[i][0/1] 表示 第 i 号结点选或者不选,令 xson[i] ,方程为

dp[i][0]+=max(dp[x][0],dp[x][1]) ,

dp[i][1]+=dp[x][0]

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 6e3+10;
int t, n, m;
int a[N];
vector<int>e[N];
ll dp[N][2], ans;

void dfs(int now,int fa){
    dp[now][1]=a[now];
    for(auto i:e[now]){
        if(i==fa) continue;
        //从下往上做,否则求得的是一条链 
        dfs(i,now); 
        //上司不去,下属去或不去都可以 
        dp[now][0] += max(dp[i][0],dp[i][1]);
        //上司去,下属不去 
        dp[now][1] += dp[i][0]; 
    }
    ans = max(ans, dp[now][0]);
    ans = max(ans, dp[now][1]);
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=1,u,v;i<n;i++){
        scanf("%d%d",&u,&v);
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,0);
    printf("%lld\n",ans);
    system("pause");
    return 0;
}

二、二叉苹果树

题意
一棵树,如果有分叉一定是二叉,每根枝条连接了一些苹果,问你从根开始保留m条枝条最多有多少果子

思路
分组背包。

dp[i][j] 表示以 i 为根,保留 j 根枝条能得到的最大果子数,xson[i]

dp[i][j]=max(dp[i][j],dp[i][jk]+dp[x][k1]+w[i][x])

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 305, inf = 0x3f3f3f3f;
int t, n, m, q, x;
int cnt, head[N], v[N];
struct Edge{
	int to, nex;
    int w;
}e[2*N];
int k, a, c;
int ans;
int dp[N][N], sz[N];

void add(int u,int v, int w){
	e[++cnt].to=v;
	e[cnt].nex=head[u];
    e[cnt].w = w;
	head[u]=cnt;
} 

void dfs1 (int now, int fa) {
    for(int i = head[now]; i; i = e[i].nex){
        int x = e[i].to;
        if(x == fa) continue;
        dfs1(x, now);
        sz[now] += sz[x];
    }
}

void dfs (int now, int fa) {
    // int s = 1;
    for (int i = head[now]; i; i = e[i].nex) {
        int x = e[i].to;
        if(x == fa) continue;
        dfs(x, now);
        // s += sz[x];
        for (int j = m; j >= 1; j--) {
            for (int k = 1; k <= j; k++){
                if(dp[now][j-k] != -inf && dp[x][k-1] != -inf)
                    dp[now][j] = max(dp[now][j], dp[now][j-k] + e[i].w + dp[x][k-1]);
            }
        }
    }
}

int main(){
    scanf("%d%d",&n, &m);
    for (int i = 1; i < n; i++) {
        scanf("%d%d%d",&a, &c, &q);
        add(a, c, q);
        add(c, a, q);
    }
    memset(dp,-inf,sizeof(dp));
    for (int i = 0; i <= n; i++) dp[i][0] = 0, sz[i] = 1;
    
    dfs1(1, 1);

    dfs(1, 1);

    printf("%d\n",dp[1][m]);

    system ("pause");
    return 0;
}

三、选课

题意
每门课有学分,同时他们之间还有选择的先后关系,问选m门课可以获得的最大学分是多少。

思路
思路同例题二,为分组背包

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 305;
int t, n, m, q, x;
int cnt, head[N], v[N];
struct Edge{
	int to, nex;
    int w;
}e[2*N];
int k, a, c;
int ans;
int dp[N][N], sz[N];

void add(int u,int v){
	e[++cnt].to=v;
	e[cnt].nex=head[u];
	head[u]=cnt;
} 

void dfs1 (int now, int fa) {
    for(int i = head[now]; i; i = e[i].nex){
        int x = e[i].to;
        if(x == fa) continue;
        dfs1(x, now);
        sz[now] += sz[x];
    }
}

void dfs (int now, int fa) {
    dp[now][1] = v[now];
    int s = 1;
    for (int i = head[now]; i; i = e[i].nex) {
        int x = e[i].to;
        if(x == fa) continue;
        dfs(x, now);
        s += sz[x];
        for (int j = min(s, m+1); j >= 0; j--) {
            for (int k = 1; k < j ; k++) {
                dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[x][k]);
            }
        }
    }
}

int main(){
    scanf("%d%d",&n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d%d",&a, &c);
        v[i] = c;
        if (a) {
            add(a, i);
        }
        else {
            add(0, i);
        }
    }
    memset(dp,-0x3f3f3f3f,sizeof(dp));
    for (int i = 0; i <= n; i++) dp[i][0] = 0, sz[i] = 1;
    
    dfs1(0, 0);

    dp[0][1] = 0;
    dfs(0, 0);

    printf("%d\n",dp[0][m+1]);

    system ("pause");
    return 0;
}

四、跑路

题意
给定一张有向图,每条路长度都是1,如果从a到b地有长度为 2k 的道路,那么时间为1,问从1到n所需的最短时间

思路
倍增 + floyd

首先预处理出所有长度为 2k 的道路,然后跑一遍floyd即可求出每两点间的最短路

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 55;
int t,n,m;
int G[N][N][111], f[N][N];

void floyd(){
    memset(f, 0x3f3f3f3f, sizeof(f));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            for (int k = 0; k <= 64; k++) {
                if (G[i][j][k]) {
                    f[i][j] = 1;
                    break;
                }
            }
        }
    }
    for (int z = 1; z <= n; z++) {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                f[i][j] = min(f[i][j], f[i][z] + f[z][j]);
            }
        }
    }
    printf("%d\n",f[1][n]);
}

int main(){
    scanf("%d%d", &n, &m);
    for (int i = 1, u, v; i <= m; i++) { 
        scanf("%d%d", &u, &v);
        G[u][v][0] = 1;
        // G[v][u][0] = 1;
    }
    for (int k = 1; k <= 64; k++) {
        for (int z = 1; z <= n; z++) {
            for (int i = 1; i <= n; i++) {
                for (int j = 1; j <= n; j++) {
                    if (G[i][z][k-1] && G[z][j][k-1]) G[i][j][k] = 1;
                }
            }
        }
    }
    floyd();

    system("pause");
    return 0;
}

五、采蘑菇

题意
给定一张有向图,每条路上有一些数量的蘑菇,如果重复经过,当前数量*=权重系数,直到数量为0.求从s点出发最多能收集到多少蘑菇

思路
因为有一些边是可以重复经过的(处于环中的),所以首先进行缩点,处理出缩成的点含有的蘑菇数,再在DAG中dfs

tips

  • 缩点需要用到tarjan/kosaraju
  • 怎么处理缩点之后每个点的内部权值?
    ——可以通过再遍历所有的边。细节见代码
for (int i = 1; i <= n; i++) {
  for (int j = head[i]; j; j = e[j].next) {
    if (f[i] == f[e[j].to]) {
	int tem = e[j].v;
	while (tem) {
	  val[f[i]] += tem;  ///
	  tem = tem * e[j].p / 10;
	}
    }
    else {
      v[f[i]].push_back({f[e[j].to], e[j].v});
    }
  }
}
  • 怎么在DAG中dfs?
    ——dp

dp[fa]=max { dp[son]+val[fa] }

点击查看代码
#include<bits/stdc++.h>    
#define pii pair<int,int> 
#define ll long long      
using namespace std;
const int N = 80005,M = 200005;
int n, m, s;
vector<pii>v[N];
bool in[N], vis[N];
int cnt,t;
int dfn[N],low[N],sta[N], f[N], val[N];
int x[M], y[M], w[M], dp[N];
double p[M];
struct edge{
	int v,next,to;
    double p;
};
edge e[M];
int head[N];

inline void add(int u,int v,int d, double p){
	cnt++;
	e[cnt].to=v;
	e[cnt].v=d;
	e[cnt].next=head[u];
    e[cnt].p = p;
	head[u]=cnt;
}

void tarjan(int now){  //本质是dfs 
	dfn[now]=low[now]=++cnt;
	sta[++t]=now;   //借助数据结构栈实现 
	in[now]=1;
	for (int i = head[now]; i; i = e[i].next) {
		int x = e[i].to;
		if(!dfn[x]){
			tarjan(x);
			low[now]=min(low[now],low[x]); //在访问x的过程中,可能遇到后向边,使x更新low值 
		}
		else{
			if(in[x]){   //如果不在栈中,表示x和now没有父子关系 ,可以无视 
				low[now]=min(low[now],dfn[x]);
			}
		}
	}
	
	if(dfn[now]==low[now]){
		int cur;
		do{
			cur=sta[t];
			f[cur]=now;
			in[cur]=0;
			t--;
		}while(now!=cur);
	}
} 

void dfs(int now){
    if (vis[now]) return;
	vis[now] = 1;
	int mx = 0;
	for (auto i:v[now]) {
		dfs(i.first);
		mx = max(mx, dp[i.first] + i.second);
	}
	dp[now] = mx + val[now];
}

int main(){
	cin>>n>>m;
		for(int i=1;i<=m;i++){
			scanf("%d%d%d%lf", &x[i], &y[i], &w[i], &p[i]);
			p[i] *= 10;
			add(x[i], y[i], w[i], p[i]);
		}

		for(int i=1;i<=n;i++){
			if(!dfn[i]){
				tarjan(i);
			}
		}

		for (int i = 1; i <= n; i++) {
			for (int j = head[i]; j; j = e[j].next) {
				if (f[i] == f[e[j].to]) {
					int tem = e[j].v;
					while (tem) {
						val[f[i]] += tem;  ///
						tem = tem * e[j].p / 10;
					}
				}
				else{
					v[f[i]].push_back({f[e[j].to], e[j].v});
				}
			}
		}

        scanf("%d", &s);
        dfs(f[s]);
		printf("%d\n", dp[f[s]]);

	system("pause");
	return 0;
}

六、加分二叉树

思路

树形dp:让每个点都当一次根,求出最大值

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 35, inf = 0x3f3f3f3f;
int t,n,m;
ll dp[N][N], ans;
int root[N][N];

ll dfs(int l, int r){
    if (l > r) return 1;
    if (dp[l][r] != -inf) return dp[l][r];
    ll tem = 0;
    for (int i = l; i <= r; i++) {
        tem = dfs(l, i-1) * dfs(i+1, r) + dp[i][i];
        if (tem > dp[l][r]) {
            dp[l][r] = tem;
            root[l][r] = i;
        }
    }
    return dp[l][r];
}

void print(int l, int r){
    if (l > r) return;
    printf("%d ",root[l][r]);
    print(l, root[l][r] - 1);
    print(root[l][r] + 1, r);
}

int main(){
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) 
        for (int j = 1; j <= n; j++) 
            dp[i][j] = -inf;
    for (int i = 1; i <= n; i++) { scanf("%lld", &dp[i][i]), root[i][i] = i; }
    printf("%lld\n", dfs(1,n));
    print(1, n); puts("");
    system("pause");
    return 0;
}

七、三色二叉树

题意
每个点只可以被染色成红or绿or蓝色,父子颜色必须不同,如果是二叉的,父亲、左儿子右儿子颜色必须都不同。求一棵树中最多和最少被染色成绿色的结点个数

思路
dp[i][0/1/2] 表示结点i被染成绿/红/蓝时子树中最多有多少个绿色的结点。最少同理

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5 + 10;
int t,n,m;
char s[N];
int dp[N][3], f[N][3];  //当前结点被染色成0/1/2 时 子树中最多/最少有多少绿点
int dfn;

void dfs (int x) {
    if (s[x] == '0') {
        dp[x][0] = 1;
        f[x][0] = 1;
        //其他是0
        return;
    }
    dfs(++dfn);
    if (s[x] == '1') {
        dp[x][0] = max(dp[x + 1][1], dp[x + 1][2]) + 1;
        dp[x][1] = max(dp[x + 1][0], dp[x + 1][2]);
        dp[x][2] = max(dp[x + 1][0], dp[x + 1][1]);

        f[x][0] = min(f[x + 1][1], f[x + 1][2]) + 1;
        f[x][1] = min(f[x + 1][0], f[x + 1][2]);
        f[x][2] = min(f[x + 1][0], f[x + 1][1]);
    }
    else{
        int k = ++dfn;
        dfs(k);
        dp[x][0] = max(dp[x + 1][1] + dp[k][2], dp[x + 1][2] + dp[k][1]) + 1;
        dp[x][1] = max(dp[x + 1][0] + dp[k][2], dp[x + 1][2] + dp[k][0]);
        dp[x][2] = max(dp[x + 1][1] + dp[k][0], dp[x + 1][0] + dp[k][1]);

        f[x][0] = min(f[x + 1][1] + f[k][2], f[x + 1][2] + f[k][1]) + 1;
        f[x][1] = min(f[x + 1][0] + f[k][2], f[x + 1][2] + f[k][0]);
        f[x][2] = min(f[x + 1][1] + f[k][0], f[x + 1][0] + f[k][1]);
    }
}   

int main(){
    scanf("%s", s+1);
    dfs(++dfn);
    int ans1 = max(dp[1][0], max(dp[1][1], dp[1][2]));
    int ans2 = min(f[1][0], min(f[1][1], f[1][2]));
    printf("%d %d\n", ans1, ans2);
    system("pause");
    return 0;
}
posted @   starlightlmy  阅读(66)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
点击右上角即可分享
微信分享提示