Codeforces 633F 树的直径/树形DP
题意:有两个小孩玩游戏,每个小孩可以选择一个起始点,并且下一个选择的点必须和自己选择的上一个点相邻,问两个选的点权和的最大值是多少?
思路:首先这个问题可以转化为求树上两不相交路径的点权和的最大值,对于这种问题,我们有两种想法:
1:树的直径,受之前HDU多校的那道题的启发,我们先找出树的直径,然后枚举保留直径的哪些部分,去找保留这一部分的最优解,去更新答案。
代码:
#include <bits/stdc++.h> #define INF 1e18 #define LL long long using namespace std; const int maxn = 100010; vector<int> G[maxn]; LL tot; int now, f[maxn]; LL d[maxn], val[maxn], sum[maxn]; bool v[maxn], v1[maxn]; LL mx[maxn]; void add(int x, int y) { G[x].push_back(y); G[y].push_back(x); } void dfs(int x, int fa, LL sum) { sum += val[x]; v1[x] = 1; f[x] = fa; if(sum > tot) { now = x; tot = sum; } for (auto y : G[x]) { if(y == fa || v[y]) continue; dfs(y, x, sum); } } void dfs1(int x, int fa) { d[x] = val[x]; LL tmp = 0; for (auto y : G[x]) { if(y == fa || v[y]) continue; dfs1(y, x); tmp = max(tmp, d[y]); } d[x] += tmp; return; } vector<int> a; int main() { int n, x, y; scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%lld", &val[i]); } for (int i = 1; i < n; i++) { scanf("%d%d", &x, &y); add(x, y); } LL ans = 0; tot = 0, now = 0; dfs(1, 0, 0); tot = 0; dfs(now, 0, 0); for (int i = now; i; i = f[i]) { a.push_back(i); sum[a.size()] = sum[a.size() - 1] + val[i]; v[i] = 1; } for (auto y : a) { dfs1(y, 0); } memset(v1, 0, sizeof(v1)); for (int i = 1; i <= n; i++) { if(v[i] == 1 || v1[i] == 1) continue; tot = now = 0; dfs(i, 0, 0); tot = 0; dfs(now, 0, 0); ans = max(ans, tot + sum[a.size()]); } mx[a[a.size() - 1]] = d[a[a.size() - 1]]; for (int i = a.size() - 2; i >= 1; i--) { mx[a[i]] = max(mx[a[i + 1]], sum[a.size()] - sum[i] + d[a[i]] - val[a[i]]); } for (int i = 0; i < a.size() - 1; i++) { ans = max(ans, d[a[i]] + sum[i] + mx[a[i + 1]]); } printf("%lld\n", ans); }
思路2:树形DP,我们对于每个点保留从它到叶子节点的最长路径和次长路径,以及以它为根的子树中的最长路径。dp完之后,我们对于每个节点,枚举选哪棵子树中的节点作为一条路径,然后去寻找另一条最长路径。可以用前缀最大值和后缀最大值去优化,以及需要注意需要考虑父节点方向对答案的影响。
代码:
#include <bits/stdc++.h> #define LL long long #define pll pair<LL, LL> using namespace std; const int maxn = 100010; vector<int> G[maxn]; LL mx_path[maxn], mx_dis[maxn]; LL lmx_path[maxn], rmx_path[maxn]; pll lmx_dis[maxn], rmx_dis[maxn]; LL val[maxn]; LL ans = 0; void add(int x, int y) { G[x].push_back(y); G[y].push_back(x); } void dfs(int x, int fa) { vector<LL> tmp(3); for (auto y : G[x]) { if(y == fa) continue; dfs(y, x); tmp[0] = mx_dis[y]; sort(tmp.begin(), tmp.end()); mx_path[x] = max(mx_path[y], mx_path[x]); } mx_dis[x] = tmp[2] + val[x]; mx_path[x] = max(mx_path[x], tmp[1] + tmp[2] + val[x]); return; } int tot, a[maxn]; struct node { int x, fa; LL mx_p, mx_d; }; queue<node> q; void solve() { q.push((node){1, 0, 0, 0}); while(q.size()) { node tmp = q.front(); q.pop(); tot = 0; int x = tmp.x, fa = tmp.fa; tot = 0; for (int i = 0; i < G[x].size(); i++) { if(G[x][i] == fa) continue; a[++tot] = G[x][i]; } rmx_path[tot + 1] = rmx_path[tot + 2] = 0; rmx_dis[tot + 1] = rmx_dis[tot + 2] = make_pair(0, 0); for (int i = 1; i <= tot; i++) { lmx_path[i] = max(lmx_path[i - 1], mx_path[a[i]]); lmx_dis[i] = lmx_dis[i - 1]; if(mx_dis[a[i]] >= lmx_dis[i].first) { lmx_dis[i].second = lmx_dis[i].first; lmx_dis[i].first = mx_dis[a[i]]; } else if(mx_dis[a[i]] > lmx_dis[i].second) { lmx_dis[i].second = mx_dis[a[i]]; } } for (int i = tot; i >= 1; i--) { rmx_path[i] = max(rmx_path[i + 1], mx_path[a[i]]); rmx_dis[i] = rmx_dis[i + 1]; if(mx_dis[a[i]] >= rmx_dis[i].first) { rmx_dis[i].second = rmx_dis[i].first; rmx_dis[i].first = mx_dis[a[i]]; } else if(mx_dis[a[i]] > rmx_dis[i].second) { rmx_dis[i].second = mx_dis[a[i]]; } } LL tmp1 = 0; // printf("%d\n", x); // for (int i = 1; i <= tot; i++) { // printf("%lld %lld\n", lmx_path[i], rmx_path[i]); // } for (int i = 1; i <= tot; i++) { LL tmp2 = 0; tmp2 = max(tmp2, lmx_path[i - 1]); tmp2 = max(tmp2, rmx_path[i + 1]); tmp2 = max(tmp2, tmp.mx_p); tmp2 = max(tmp2, val[x] + lmx_dis[i - 1].first + tmp.mx_d); tmp2 = max(tmp2, val[x] + rmx_dis[i + 1].first + tmp.mx_d); tmp2 = max(tmp2, val[x] + lmx_dis[i - 1].first + lmx_dis[i - 1].second); tmp2 = max(tmp2, val[x] + rmx_dis[i + 1].first + rmx_dis[i + 1].second); tmp2 = max(tmp2, val[x] + lmx_dis[i - 1].first + rmx_dis[i + 1].first); tmp2 = max(tmp2, val[x] + tmp.mx_d + max(lmx_dis[i - 1].first, rmx_dis[i + 1].first)); // printf("%lld ", tmp.mx_d); // printf("%lld ", tmp2); q.push((node){a[i], x, tmp2, val[x] + max(max(lmx_dis[i - 1].first, rmx_dis[i + 1].first), tmp.mx_d)}); tmp2 += mx_path[a[i]]; // printf("%lld\n", tmp2); tmp1 = max(tmp1, tmp2); } // printf("\n"); ans = max(ans, tmp1); ans = max(ans, tmp.mx_p + mx_path[x]); } } int main() { int n, x, y; // freopen("633Fin.txt", "r" , stdin); // freopen("out1.txt", "w", stdout); scanf("%d" , &n); for (int i = 1; i <= n; i++) scanf("%lld", &val[i]); for (int i = 1; i < n; i++) { scanf("%d%d", &x, &y); add(x, y); } dfs(1, 0); solve(); printf("%lld\n", ans); }