树形dp例题 + 学习笔记(入门版)
树形dp,即在树上进行dp。
需要对树这一数据结构有清晰的了解,还需要学会树的遍历。
难点常常在于状态方程的书写。
例题
例题都来自https://www.luogu.com.cn/training/214#problems
题意
树上每个结点有权值,要求在树上选一些点,满足有父子关系的结点只能出现一个,问选出的最大的权值和。
思路
用 表示 第 号结点选或者不选,令 ,方程为
,
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 6e3+10;
int t, n, m;
int a[N];
vector<int>e[N];
ll dp[N][2], ans;
void dfs(int now,int fa){
dp[now][1]=a[now];
for(auto i:e[now]){
if(i==fa) continue;
//从下往上做,否则求得的是一条链
dfs(i,now);
//上司不去,下属去或不去都可以
dp[now][0] += max(dp[i][0],dp[i][1]);
//上司去,下属不去
dp[now][1] += dp[i][0];
}
ans = max(ans, dp[now][0]);
ans = max(ans, dp[now][1]);
}
int main(){
cin>>n;
for(int i=1;i<=n;i++){
cin>>a[i];
}
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
e[u].pb(v);
e[v].pb(u);
}
dfs(1,0);
printf("%lld\n",ans);
system("pause");
return 0;
}
题意
一棵树,如果有分叉一定是二叉,每根枝条连接了一些苹果,问你从根开始保留m条枝条最多有多少果子
思路
分组背包。
令 表示以 为根,保留 根枝条能得到的最大果子数,
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 305, inf = 0x3f3f3f3f;
int t, n, m, q, x;
int cnt, head[N], v[N];
struct Edge{
int to, nex;
int w;
}e[2*N];
int k, a, c;
int ans;
int dp[N][N], sz[N];
void add(int u,int v, int w){
e[++cnt].to=v;
e[cnt].nex=head[u];
e[cnt].w = w;
head[u]=cnt;
}
void dfs1 (int now, int fa) {
for(int i = head[now]; i; i = e[i].nex){
int x = e[i].to;
if(x == fa) continue;
dfs1(x, now);
sz[now] += sz[x];
}
}
void dfs (int now, int fa) {
// int s = 1;
for (int i = head[now]; i; i = e[i].nex) {
int x = e[i].to;
if(x == fa) continue;
dfs(x, now);
// s += sz[x];
for (int j = m; j >= 1; j--) {
for (int k = 1; k <= j; k++){
if(dp[now][j-k] != -inf && dp[x][k-1] != -inf)
dp[now][j] = max(dp[now][j], dp[now][j-k] + e[i].w + dp[x][k-1]);
}
}
}
}
int main(){
scanf("%d%d",&n, &m);
for (int i = 1; i < n; i++) {
scanf("%d%d%d",&a, &c, &q);
add(a, c, q);
add(c, a, q);
}
memset(dp,-inf,sizeof(dp));
for (int i = 0; i <= n; i++) dp[i][0] = 0, sz[i] = 1;
dfs1(1, 1);
dfs(1, 1);
printf("%d\n",dp[1][m]);
system ("pause");
return 0;
}
题意
每门课有学分,同时他们之间还有选择的先后关系,问选m门课可以获得的最大学分是多少。
思路
思路同例题二,为分组背包
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 305;
int t, n, m, q, x;
int cnt, head[N], v[N];
struct Edge{
int to, nex;
int w;
}e[2*N];
int k, a, c;
int ans;
int dp[N][N], sz[N];
void add(int u,int v){
e[++cnt].to=v;
e[cnt].nex=head[u];
head[u]=cnt;
}
void dfs1 (int now, int fa) {
for(int i = head[now]; i; i = e[i].nex){
int x = e[i].to;
if(x == fa) continue;
dfs1(x, now);
sz[now] += sz[x];
}
}
void dfs (int now, int fa) {
dp[now][1] = v[now];
int s = 1;
for (int i = head[now]; i; i = e[i].nex) {
int x = e[i].to;
if(x == fa) continue;
dfs(x, now);
s += sz[x];
for (int j = min(s, m+1); j >= 0; j--) {
for (int k = 1; k < j ; k++) {
dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[x][k]);
}
}
}
}
int main(){
scanf("%d%d",&n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d%d",&a, &c);
v[i] = c;
if (a) {
add(a, i);
}
else {
add(0, i);
}
}
memset(dp,-0x3f3f3f3f,sizeof(dp));
for (int i = 0; i <= n; i++) dp[i][0] = 0, sz[i] = 1;
dfs1(0, 0);
dp[0][1] = 0;
dfs(0, 0);
printf("%d\n",dp[0][m+1]);
system ("pause");
return 0;
}
题意
给定一张有向图,每条路长度都是1,如果从a到b地有长度为 的道路,那么时间为1,问从1到n所需的最短时间
思路
倍增 + floyd
首先预处理出所有长度为 的道路,然后跑一遍floyd即可求出每两点间的最短路
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 55;
int t,n,m;
int G[N][N][111], f[N][N];
void floyd(){
memset(f, 0x3f3f3f3f, sizeof(f));
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
for (int k = 0; k <= 64; k++) {
if (G[i][j][k]) {
f[i][j] = 1;
break;
}
}
}
}
for (int z = 1; z <= n; z++) {
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
f[i][j] = min(f[i][j], f[i][z] + f[z][j]);
}
}
}
printf("%d\n",f[1][n]);
}
int main(){
scanf("%d%d", &n, &m);
for (int i = 1, u, v; i <= m; i++) {
scanf("%d%d", &u, &v);
G[u][v][0] = 1;
// G[v][u][0] = 1;
}
for (int k = 1; k <= 64; k++) {
for (int z = 1; z <= n; z++) {
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
if (G[i][z][k-1] && G[z][j][k-1]) G[i][j][k] = 1;
}
}
}
}
floyd();
system("pause");
return 0;
}
题意
给定一张有向图,每条路上有一些数量的蘑菇,如果重复经过,当前数量*=权重系数,直到数量为0.求从s点出发最多能收集到多少蘑菇
思路
因为有一些边是可以重复经过的(处于环中的),所以首先进行缩点,处理出缩成的点含有的蘑菇数,再在DAG中dfs
tips
- 缩点需要用到tarjan/kosaraju
- 怎么处理缩点之后每个点的内部权值?
——可以通过再遍历所有的边。细节见代码
for (int i = 1; i <= n; i++) {
for (int j = head[i]; j; j = e[j].next) {
if (f[i] == f[e[j].to]) {
int tem = e[j].v;
while (tem) {
val[f[i]] += tem; ///
tem = tem * e[j].p / 10;
}
}
else {
v[f[i]].push_back({f[e[j].to], e[j].v});
}
}
}
- 怎么在DAG中dfs?
——dp
{ }
点击查看代码
#include<bits/stdc++.h>
#define pii pair<int,int>
#define ll long long
using namespace std;
const int N = 80005,M = 200005;
int n, m, s;
vector<pii>v[N];
bool in[N], vis[N];
int cnt,t;
int dfn[N],low[N],sta[N], f[N], val[N];
int x[M], y[M], w[M], dp[N];
double p[M];
struct edge{
int v,next,to;
double p;
};
edge e[M];
int head[N];
inline void add(int u,int v,int d, double p){
cnt++;
e[cnt].to=v;
e[cnt].v=d;
e[cnt].next=head[u];
e[cnt].p = p;
head[u]=cnt;
}
void tarjan(int now){ //本质是dfs
dfn[now]=low[now]=++cnt;
sta[++t]=now; //借助数据结构栈实现
in[now]=1;
for (int i = head[now]; i; i = e[i].next) {
int x = e[i].to;
if(!dfn[x]){
tarjan(x);
low[now]=min(low[now],low[x]); //在访问x的过程中,可能遇到后向边,使x更新low值
}
else{
if(in[x]){ //如果不在栈中,表示x和now没有父子关系 ,可以无视
low[now]=min(low[now],dfn[x]);
}
}
}
if(dfn[now]==low[now]){
int cur;
do{
cur=sta[t];
f[cur]=now;
in[cur]=0;
t--;
}while(now!=cur);
}
}
void dfs(int now){
if (vis[now]) return;
vis[now] = 1;
int mx = 0;
for (auto i:v[now]) {
dfs(i.first);
mx = max(mx, dp[i.first] + i.second);
}
dp[now] = mx + val[now];
}
int main(){
cin>>n>>m;
for(int i=1;i<=m;i++){
scanf("%d%d%d%lf", &x[i], &y[i], &w[i], &p[i]);
p[i] *= 10;
add(x[i], y[i], w[i], p[i]);
}
for(int i=1;i<=n;i++){
if(!dfn[i]){
tarjan(i);
}
}
for (int i = 1; i <= n; i++) {
for (int j = head[i]; j; j = e[j].next) {
if (f[i] == f[e[j].to]) {
int tem = e[j].v;
while (tem) {
val[f[i]] += tem; ///
tem = tem * e[j].p / 10;
}
}
else{
v[f[i]].push_back({f[e[j].to], e[j].v});
}
}
}
scanf("%d", &s);
dfs(f[s]);
printf("%d\n", dp[f[s]]);
system("pause");
return 0;
}
思路
树形dp:让每个点都当一次根,求出最大值
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 35, inf = 0x3f3f3f3f;
int t,n,m;
ll dp[N][N], ans;
int root[N][N];
ll dfs(int l, int r){
if (l > r) return 1;
if (dp[l][r] != -inf) return dp[l][r];
ll tem = 0;
for (int i = l; i <= r; i++) {
tem = dfs(l, i-1) * dfs(i+1, r) + dp[i][i];
if (tem > dp[l][r]) {
dp[l][r] = tem;
root[l][r] = i;
}
}
return dp[l][r];
}
void print(int l, int r){
if (l > r) return;
printf("%d ",root[l][r]);
print(l, root[l][r] - 1);
print(root[l][r] + 1, r);
}
int main(){
scanf("%d", &n);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
dp[i][j] = -inf;
for (int i = 1; i <= n; i++) { scanf("%lld", &dp[i][i]), root[i][i] = i; }
printf("%lld\n", dfs(1,n));
print(1, n); puts("");
system("pause");
return 0;
}
题意
每个点只可以被染色成红or绿or蓝色,父子颜色必须不同,如果是二叉的,父亲、左儿子右儿子颜色必须都不同。求一棵树中最多和最少被染色成绿色的结点个数
思路
用 表示结点i被染成绿/红/蓝时子树中最多有多少个绿色的结点。最少同理
点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pb push_back
using namespace std;
const int N = 5e5 + 10;
int t,n,m;
char s[N];
int dp[N][3], f[N][3]; //当前结点被染色成0/1/2 时 子树中最多/最少有多少绿点
int dfn;
void dfs (int x) {
if (s[x] == '0') {
dp[x][0] = 1;
f[x][0] = 1;
//其他是0
return;
}
dfs(++dfn);
if (s[x] == '1') {
dp[x][0] = max(dp[x + 1][1], dp[x + 1][2]) + 1;
dp[x][1] = max(dp[x + 1][0], dp[x + 1][2]);
dp[x][2] = max(dp[x + 1][0], dp[x + 1][1]);
f[x][0] = min(f[x + 1][1], f[x + 1][2]) + 1;
f[x][1] = min(f[x + 1][0], f[x + 1][2]);
f[x][2] = min(f[x + 1][0], f[x + 1][1]);
}
else{
int k = ++dfn;
dfs(k);
dp[x][0] = max(dp[x + 1][1] + dp[k][2], dp[x + 1][2] + dp[k][1]) + 1;
dp[x][1] = max(dp[x + 1][0] + dp[k][2], dp[x + 1][2] + dp[k][0]);
dp[x][2] = max(dp[x + 1][1] + dp[k][0], dp[x + 1][0] + dp[k][1]);
f[x][0] = min(f[x + 1][1] + f[k][2], f[x + 1][2] + f[k][1]) + 1;
f[x][1] = min(f[x + 1][0] + f[k][2], f[x + 1][2] + f[k][0]);
f[x][2] = min(f[x + 1][1] + f[k][0], f[x + 1][0] + f[k][1]);
}
}
int main(){
scanf("%s", s+1);
dfs(++dfn);
int ans1 = max(dp[1][0], max(dp[1][1], dp[1][2]));
int ans2 = min(f[1][0], min(f[1][1], f[1][2]));
printf("%d %d\n", ans1, ans2);
system("pause");
return 0;
}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!