树形 dp

树形 dp

概念

  • 在树上做 dp
  • 树形 dp 一般是从树的叶子节点向根的做 dp,也就是自下而上做 dp

树上 dp 加差分统计

  • 记住差分,在做很多树上的统计题时,都会用到
点击查看代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;

const int MAXN = 2e5 + 3;

int n, ANS = 0;
int dp[MAXN][3];
vector<int> eg[MAXN];

void dfs(int x, int dad){
  for(int nxt : eg[x]){
    if(nxt == dad) continue;
    dfs(nxt, x);
    dp[x][0] += max(dp[nxt][0], dp[nxt][1]);
  }
  for(int nxt : eg[x]){
    if(nxt == dad) continue;
    dp[x][1] = max(dp[x][1], dp[x][0] - max(dp[nxt][0], dp[nxt][1]) + dp[nxt][0] + 1);
  }
}

int main(){
  cin >> n;
  for(int i = 1, U, V; i < n; i++){
    cin >> U >> V;
    eg[U].push_back(V);
    eg[V].push_back(U);
  }
  dfs(1, 0);
  cout << max(dp[1][1], dp[1][0]);
  return 0;
}

树上 dp 记录最大值、次大值

  • 需要注意记录最大值、次大值时的细节
  • 有时还需要两个 \(pair\) 结合求最大值、次大值,那样细节会更加的多
  • 记录最大值、次大值,在做很多树上的 dp 时,都会用到
点击查看代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;

const int MAXN = 2e5 + 3;

int n, ANS = 0;
int ans[MAXN];
int dp[MAXN][3];
vector<int> eg[MAXN];

void dfs(int x, int dad){
  for(int nxt : eg[x]){
    if(nxt == dad) continue;
    dfs(nxt, x);
    if(dp[x][1] <= dp[nxt][1] + 1){
      dp[x][2] = dp[x][1];
      dp[x][1] = dp[nxt][1] + 1;
    }else{
      dp[x][2] = max(dp[x][2], dp[nxt][1] + 1);
    }
  }
  //cout << x << " " << dp[x][1] << " " << dp[x][2] << "\n";
  ANS = max(ANS, dp[x][1] + dp[x][2]);
}

int main(){
  cin >> n;
  for(int i = 1, U, V; i < n; i++){
    cin >> U >> V;
    eg[U].push_back(V);
    eg[V].push_back(U);
  }
  dfs(1, 0);
  cout << ANS;
  return 0;
}

树形 dp 分类讨论

  • 这种题目要多写,可以发现其中的套路
  • 一般这种题由两种
    • 要推导公式,要差分,主要是数学 例题 1
    • 其中的某一个值要么很大,要么很小,从这个突破口来设计状态 例题 2
点击查看例题 1 代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;

const int MAXN = 1e6 + 3;
const LL mod = 998244353;

int n;
LL dp[MAXN][4];
/*
0: 被删除
1:没有被删除且还没有连点
2: 没有被删除且有连点
*/
vector<int> eg[MAXN];

LL qpow(LL A, LL B){
  LL sum = A, ANS = 1;
  for(int i = 0; i < 60; i++){
    if((B >> i) & 1){
      ANS = (ANS * sum) % mod;
    }
    sum = (sum * sum) % mod;
  }
  return ANS;
}

void dfs(int x, int dad){
  dp[x][0] = dp[x][1] = 1;
  for(int nxt : eg[x]){
    if(nxt == dad) continue;
    dfs(nxt, x);
  }
  LL sum = 1;
  for(int i = 0; i < eg[x].size(); i++){
    int nxt = eg[x][i];
    if(nxt == dad) continue;
    sum *= (dp[nxt][2] + dp[nxt][1] + dp[nxt][0]) % mod;
    sum %= mod;
    dp[x][0] *= (dp[nxt][2] + dp[nxt][0]) % mod;
    dp[x][0] %= mod;
    dp[x][1] *= dp[nxt][0] % mod;
    dp[x][1] %= mod;
  }
  dp[x][2] = (sum - dp[x][1] + mod) % mod;
}

void work(){
  cin >> n;
  for(int i = 1; i <= n; i++) eg[i].clear();
  for(int i = 1, U, V; i < n; i++){
    cin >> U >> V;
    eg[U].push_back(V);
    eg[V].push_back(U);
  }
  dfs(1, 0);
  cout << (dp[1][0] + dp[1][2]) % mod << "\n";
}

int main(){
  ios::sync_with_stdio(0), cin.tie(0);
  int T;
  cin >> T;
  while(T--){
    work();
  }
  return 0;
}
点击查看例题 2 代码
#include <bits/stdc++.h>

using namespace std;
using LL = long long;

const int MAXN = 1e5 + 3;

int n, t[MAXN];
LL a[MAXN];
LL dp[MAXN][3];
vector<int> eg[MAXN];

void dfs(int x, int dad){
  LL sum = 0, MAX1 = 0, MAX2 = 0;
  for(int nxt : eg[x]){
    if(nxt == dad) continue;
    dfs(nxt, x), sum += dp[nxt][1];
    if(t[nxt] == 3){
      if(a[nxt] >= a[MAX1]){
        MAX2 = MAX1, MAX1 = nxt;
      }else if(a[nxt] >= a[MAX2]){
        MAX2 = nxt;
      }
    }
  }
  dp[x][0] = a[x], dp[x][2] = a[x] + sum;
  for(int nxt : eg[x]){
    if(nxt == dad) continue;
    dp[x][0] = max(dp[x][0], dp[nxt][0] + a[x] + (sum - dp[nxt][1]));
    int p = (MAX1 == nxt ? MAX2 : MAX1);
    if(p > 0){
      dp[x][0] = max(dp[x][0], dp[nxt][2] + a[x] + sum - dp[nxt][1] + a[p]);
    }
  }
  dp[x][1] = dp[x][0] - a[x];
}

void work(){
  cin >> n;
  for(int i = 1; i <= n; i++){
    cin >> a[i];
  }
  for(int i = 1; i <= n; i++){
    cin >> t[i], eg[i].clear();;
  }
  for(int i = 1, U, V; i < n; i++){
    cin >> U >> V;
    eg[U].push_back(V);
    eg[V].push_back(U);
  }
  dfs(1, 0);
  cout << dp[1][0] << "\n";
}

int main(){
  ios::sync_with_stdio(0), cin.tie(0);
  int T;
  cin >> T;
  while(T--) work();
  return 0;
}
posted @ 2023-08-15 19:47  hhhqx  阅读(16)  评论(0编辑  收藏  举报