树形 dp
概念
- 在树上做 dp
- 树形 dp 一般是从树的叶子节点向根的做 dp,也就是自下而上做 dp
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 2e5 + 3;
int n, ANS = 0;
int dp[MAXN][3];
vector<int> eg[MAXN];
void dfs(int x, int dad){
for(int nxt : eg[x]){
if(nxt == dad) continue;
dfs(nxt, x);
dp[x][0] += max(dp[nxt][0], dp[nxt][1]);
}
for(int nxt : eg[x]){
if(nxt == dad) continue;
dp[x][1] = max(dp[x][1], dp[x][0] - max(dp[nxt][0], dp[nxt][1]) + dp[nxt][0] + 1);
}
}
int main(){
cin >> n;
for(int i = 1, U, V; i < n; i++){
cin >> U >> V;
eg[U].push_back(V);
eg[V].push_back(U);
}
dfs(1, 0);
cout << max(dp[1][1], dp[1][0]);
return 0;
}
- 需要注意记录最大值、次大值时的细节
- 有时还需要两个 \(pair\) 结合求最大值、次大值,那样细节会更加的多
- 记录最大值、次大值,在做很多树上的 dp 时,都会用到
点击查看代码
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 2e5 + 3;
int n, ANS = 0;
int ans[MAXN];
int dp[MAXN][3];
vector<int> eg[MAXN];
void dfs(int x, int dad){
for(int nxt : eg[x]){
if(nxt == dad) continue;
dfs(nxt, x);
if(dp[x][1] <= dp[nxt][1] + 1){
dp[x][2] = dp[x][1];
dp[x][1] = dp[nxt][1] + 1;
}else{
dp[x][2] = max(dp[x][2], dp[nxt][1] + 1);
}
}
//cout << x << " " << dp[x][1] << " " << dp[x][2] << "\n";
ANS = max(ANS, dp[x][1] + dp[x][2]);
}
int main(){
cin >> n;
for(int i = 1, U, V; i < n; i++){
cin >> U >> V;
eg[U].push_back(V);
eg[V].push_back(U);
}
dfs(1, 0);
cout << ANS;
return 0;
}
树形 dp 分类讨论
- 这种题目要多写,可以发现其中的套路
- 一般这种题由两种
- 要推导公式,要差分,主要是数学 例题 1
- 其中的某一个值要么很大,要么很小,从这个突破口来设计状态 例题 2
点击查看例题 1 代码
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 1e6 + 3;
const LL mod = 998244353;
int n;
LL dp[MAXN][4];
/*
0: 被删除
1:没有被删除且还没有连点
2: 没有被删除且有连点
*/
vector<int> eg[MAXN];
LL qpow(LL A, LL B){
LL sum = A, ANS = 1;
for(int i = 0; i < 60; i++){
if((B >> i) & 1){
ANS = (ANS * sum) % mod;
}
sum = (sum * sum) % mod;
}
return ANS;
}
void dfs(int x, int dad){
dp[x][0] = dp[x][1] = 1;
for(int nxt : eg[x]){
if(nxt == dad) continue;
dfs(nxt, x);
}
LL sum = 1;
for(int i = 0; i < eg[x].size(); i++){
int nxt = eg[x][i];
if(nxt == dad) continue;
sum *= (dp[nxt][2] + dp[nxt][1] + dp[nxt][0]) % mod;
sum %= mod;
dp[x][0] *= (dp[nxt][2] + dp[nxt][0]) % mod;
dp[x][0] %= mod;
dp[x][1] *= dp[nxt][0] % mod;
dp[x][1] %= mod;
}
dp[x][2] = (sum - dp[x][1] + mod) % mod;
}
void work(){
cin >> n;
for(int i = 1; i <= n; i++) eg[i].clear();
for(int i = 1, U, V; i < n; i++){
cin >> U >> V;
eg[U].push_back(V);
eg[V].push_back(U);
}
dfs(1, 0);
cout << (dp[1][0] + dp[1][2]) % mod << "\n";
}
int main(){
ios::sync_with_stdio(0), cin.tie(0);
int T;
cin >> T;
while(T--){
work();
}
return 0;
}
点击查看例题 2 代码
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
const int MAXN = 1e5 + 3;
int n, t[MAXN];
LL a[MAXN];
LL dp[MAXN][3];
vector<int> eg[MAXN];
void dfs(int x, int dad){
LL sum = 0, MAX1 = 0, MAX2 = 0;
for(int nxt : eg[x]){
if(nxt == dad) continue;
dfs(nxt, x), sum += dp[nxt][1];
if(t[nxt] == 3){
if(a[nxt] >= a[MAX1]){
MAX2 = MAX1, MAX1 = nxt;
}else if(a[nxt] >= a[MAX2]){
MAX2 = nxt;
}
}
}
dp[x][0] = a[x], dp[x][2] = a[x] + sum;
for(int nxt : eg[x]){
if(nxt == dad) continue;
dp[x][0] = max(dp[x][0], dp[nxt][0] + a[x] + (sum - dp[nxt][1]));
int p = (MAX1 == nxt ? MAX2 : MAX1);
if(p > 0){
dp[x][0] = max(dp[x][0], dp[nxt][2] + a[x] + sum - dp[nxt][1] + a[p]);
}
}
dp[x][1] = dp[x][0] - a[x];
}
void work(){
cin >> n;
for(int i = 1; i <= n; i++){
cin >> a[i];
}
for(int i = 1; i <= n; i++){
cin >> t[i], eg[i].clear();;
}
for(int i = 1, U, V; i < n; i++){
cin >> U >> V;
eg[U].push_back(V);
eg[V].push_back(U);
}
dfs(1, 0);
cout << dp[1][0] << "\n";
}
int main(){
ios::sync_with_stdio(0), cin.tie(0);
int T;
cin >> T;
while(T--) work();
return 0;
}