Codeforces 1179D 树形DP 斜率优化
题意:给你一颗树,你可以在树上添加一条边,问添加一条边之后的简单路径最多有多少条?简单路径是指路径中的点只没有重复。
思路:添加一条边之后,树变成了基环树。容易发现,以基环上的点为根的子树的点中的简单路径没有增加。所以,问题相当于转化为找一个基环,使得以基环上的点为根的子树Σ(i从1到n) sz[i] * (sz[i] - 1) / 2最小。我们把式子转化一下变成求(sz[i]的平方和 - n) / 2。相当于我们需要求sz[i]的平方和。但是,我们并不知道哪个是基环,怎么求sz呢?我们发现一个性质:添加的边连接的两点一定是树中度数为1的点,否则,我们一定可以缩小平方和。所以,根据这个性质,我们可以进行树形dp。设dp[i]为以i为根的子树中,选择从i到子树中的某个叶子节点的路径为基环上的点,可以获得的最小的平方和。dp[i] = min(dp[son] + (sz[i] - sz[son]) ^ 2)。
我们假设选择的基环是u -> lca(u, v) -> v ,假设fu为u到lca(u, v)的路径中lca(u, v)的前面一个节点,fv同理,那么平方和为ans = dp[fu] + dp[fv] + (n - sz[fu] - sz[fv]) ^ 2。所以,我们在深搜的时候,找到所有孩子的dp值和sz,枚举是哪两个孩子来更新平方和,这样最坏情况是O(n ^ 2)的,会超时。发现状态转移方程中有fu和fv的乘积项,我们可以考虑斜率优化。把方程移项: dp[fv] = 2 * (n - sz[fu]) * sz[fv] + (ans - dp[fu] - 2 * n * sz[fu])。那么相当于是以sz[fv]为横坐标,dp[fv]为纵坐标,斜率为2 * (n - sz[fu])的直线,要ans最小,需要截距最小。我们把sz从小到大排序,用单调队列维护一个下凸包,之后在单调队列里二分即可。注意的细节:1,二分之后需要判断合不合法,不能fu和fv相等了。2:斜率优化只考虑的fu和fv不等的情况,我们需要特判一下从最优的叶子结点直接连到当前结点的这种情况。
代码:
#pragma comment(linker, "/stack:200000000") #include <bits/stdc++.h> #define LL long long #define pll pair<LL, LL> using namespace std; const int maxn = 500010; vector<int> G[maxn]; LL sz[maxn], dp[maxn]; pll q[maxn], a[maxn], b[maxn]; int l, r; LL n, ans; int tot; map<pll, int> mp; void add(int x, int y) { G[x].push_back(y); G[y].push_back(x); } bool check(pll x, pll y, pll z) { if((y.second - x.second) * (z.first - y.first) < (y.first - x.first) * (z.second - y.second)) return 1; else return 0; } int binary_search(pll x, LL k) { if(l == r) return l; int L = l, R = r; while(L < R) { int mid = (L + R) >> 1; if((q[mid + 1].second - q[mid].second) <= k * (q[mid + 1].first - q[mid].first)) L = mid + 1; else R = mid; } return L; } void dfs(int x, int fa) { sz[x] = 1; for (auto y : G[x]) { if(y == fa) continue; dfs(y, x); sz[x] += sz[y]; } for (auto y : G[x]) { if(y == fa) continue; dp[x] = min(dp[x], (sz[x] - sz[y]) * (sz[x] - sz[y]) + dp[y]); } tot = 0; for (auto y : G[x]) { if(y == fa) continue; b[++tot] = a[y]; } sort(b + 1, b + 1 + tot); mp.clear(); if(tot > 1) { l = 1, r = 0; for (int i = 1; i <= tot; i++) { mp[b[i]]++; while(l < r && !check(q[r - 1], q[r], b[i])) r--; q[++r] = b[i]; } for (int i = 1; i <= tot; i++) { int pos = binary_search(b[i], 2 * (n - b[i].first)); if(q[pos] == b[i]) { if(mp[b[i]] > 1) { ans = min(ans, b[i].second + q[pos].second + (n - b[i].first - q[pos].first) * (n - b[i].first - q[pos].first)); } else { if(pos < r) ans = min(ans, b[i].second + q[pos + 1].second + (n - b[i].first - q[pos + 1].first) * (n - b[i].first - q[pos + 1].first)); else ans = min(ans, b[i].second + q[pos - 1].second + (n - b[i].first - q[pos - 1].first) * (n - b[i].first - q[pos - 1].first)); } } else { ans = min(ans, b[i].second + q[pos].second + (n - b[i].first - q[pos].first) * (n - b[i].first - q[pos].first)); } } } for (int i = 1; i <= tot; i++) { ans = min(ans, b[i].second + (n - b[i].first) * (n - b[i].first)); } if(fa != -1 && G[x].size() == 1) dp[x] = sz[x] * sz[x]; a[x] = make_pair(sz[x], dp[x]); } int main() { int x, y; memset(dp, 0x3f, sizeof(dp)); // freopen("1179Din.txt", "r", stdin); // freopen("1179D1out.txt", "w", stdout); scanf("%lld", &n); for (int i = 1; i < n; i++) { scanf("%d%d", &x, &y); add(x, y); } ans = 1e18; dfs(1, -1); ans = min(ans, dp[1]); ans -= n; ans /= 2; ans = 2ll * n * (n - 1) / 2ll - ans; printf("%lld\n", ans); }