换根dp
换根dp
描述:
题目不指定根节点,需要分别以每一个节点为根进行计算,找到使答案最优的那个根节点。
暴力做法是枚举每个节点作为根节点,再遍历整棵树,时间复杂度为 O(n^2)。
换根dp可以把时间复杂度降到 O(n):当根从节点 u 切换到相邻节点 v 时,直接利用已经计算过的数据,在 O(1) 时间内得到以 v 为根的结果。
一般通过两次扫描实现:第一次 dfs 以某个节点(如 1)为根,预处理出子树信息;第二次 dfs 自顶向下完成根节点的切换。
例题
AtCoder ABC348 E - Minimize Sum of Distances
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 1e5 + 5;
// Adjacency lists of the tree (vertices 1..n). Global storage, reset at the
// start of every solve() call so the function can be run more than once.
vector<int> adj[maxn];

// AtCoder ABC348 E "Minimize Sum of Distances" -- rerooting DP.
//
// Reads n, the n-1 tree edges, then the weights C_1..C_n, and prints
//     min over x of f(x) = sum_i C_i * dist(x, i).
//
// Pass 1 (tree rooted at 1):
//     cnt[x] = total weight in the subtree of x,
//     p[x]   = sum of C_i * dist(x, i) over that subtree, so f(1) = p[1].
// Pass 2: moving the root from u to its child w brings the cnt[w] weight of
// w's subtree one edge closer and the remaining cnt[1] - cnt[w] weight one
// edge farther:
//     f(w) = f(u) + cnt[1] - 2 * cnt[w].
// Weighted sums easily exceed 32 bits, so they are held in long long
// explicitly rather than relying on a file-level `#define int long long`.
void solve() {
    int n;
    cin >> n;
    // Drop edges left over from a previous call.
    for (int i = 1; i <= n; i++) {
        adj[i].clear();
    }
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    vector<long long> weight(n + 1);
    for (int i = 1; i <= n; i++) {
        cin >> weight[i];
    }

    vector<int> fa(n + 1);  // parent in the tree rooted at 1; fa[1] = 0
    vector<long long> cnt(n + 1);
    vector<long long> p(n + 1);
    function<void(int)> dfs = [&](int x) {
        cnt[x] = weight[x];
        p[x] = 0;
        for (int y : adj[x]) {
            if (y != fa[x]) {
                fa[y] = x;
                dfs(y);
                cnt[x] += cnt[y];
                // Every vertex under y is one edge farther from x than from y.
                p[x] += p[y] + cnt[y];
            }
        }
    };
    dfs(1);

    vector<long long> f(n + 1);
    f[1] = p[1];
    long long res = f[1];
    function<void(int)> reroot = [&](int u) {
        for (int w : adj[u]) {
            if (w != fa[u]) {
                f[w] = f[u] + cnt[1] - 2 * cnt[w];
                res = min(res, f[w]);
                reroot(w);
            }
        }
    };
    reroot(1);
    cout << res << '\n';
}
// Program entry: switch to unsynchronised stream I/O, then solve one test case.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    for (int remaining = 1; remaining != 0; --remaining) {
        solve();
    }
    return 0;
}
积蓄程度(POJ 3585 Accumulation Degree)
注意:本题卡 map(用 map 存边会超时),应使用邻接表存图。
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e5 + 5;
// Weighted adjacency lists (vertices 1..n): adj[u] holds (neighbour, capacity).
// Reset at the start of every solve() call.
vector<pair<int, long long>> adj[maxn];

// POJ 3585 "Accumulation Degree" -- rerooting DP, one test case per call.
//
// Reads n and the n-1 edges (a, b, c) with capacity c, and prints the maximum,
// over every choice of source vertex, of the flow that can be pushed from the
// source into the leaves.
//
// Pass 1 (tree rooted at 1):
//     p[u]  = max flow from u down into its own subtree,
//     up[w] = what child w contributes to p[parent(w)]: the full edge capacity
//             if w is a leaf, otherwise min(p[w], capacity).
// Pass 2: for a child w of u,
//     f[w] = p[w] + (u is a root of degree 1 ? c : min(f[u] - up[w], c)),
// i.e. the flow w can push back through u is what u absorbs from everything
// except w's subtree, capped by the edge capacity c.
void solve() {
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        adj[i].clear();
    }
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        long long c;
        cin >> a >> b >> c;
        adj[a].push_back({b, c});
        adj[b].push_back({a, c});
    }

    vector<long long> p(n + 1);
    vector<long long> up(n + 1);
    vector<long long> f(n + 1);

    function<void(int, int)> dfs1 = [&](int u, int parent) {
        p[u] = 0;
        for (const auto& edge : adj[u]) {
            int w = edge.first;
            if (w == parent) {
                continue;
            }
            dfs1(w, u);
            long long c = edge.second;
            // A non-root vertex with a single neighbour is a leaf and absorbs
            // everything the edge carries; an inner vertex passes on at most
            // what its own subtree can absorb.
            up[w] = (adj[w].size() == 1) ? c : min(p[w], c);
            p[u] += up[w];
        }
    };
    dfs1(1, 0);

    f[1] = p[1];
    long long ans = f[1];
    function<void(int, int)> dfs2 = [&](int u, int parent) {
        for (const auto& edge : adj[u]) {
            int w = edge.first;
            if (w == parent) {
                continue;
            }
            long long c = edge.second;
            if (adj[u].size() == 1) {
                // Only the root can have a single neighbour here; with w as the
                // new root, u becomes a leaf that absorbs the whole capacity.
                f[w] = p[w] + c;
            } else {
                // Remove exactly what w contributed to f[u] (up[w], which is
                // the full capacity when w is a leaf).
                f[w] = p[w] + min(f[u] - up[w], c);
            }
            ans = max(ans, f[w]);
            dfs2(w, u);
        }
    };
    dfs2(1, 0);
    cout << ans << '\n';
}
// Program entry: switch to unsynchronised stream I/O, read the number of test
// cases, and run solve() once per case.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    int testCount = 1;
    std::cin >> testCount;
    for (; testCount != 0; --testCount) {
        solve();
    }
    return 0;
}
P3478 [POI2008] STA-Station
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 1e6 + 5;
// Adjacency lists of the tree (vertices 1..n), reset at the start of solve().
vector<int> adj[maxn];

// Luogu P3478 [POI2008] STA-Station -- rerooting DP.
//
// Reads n and the n-1 tree edges, and prints a vertex u that maximises the sum
// of depths of all vertices when the tree is rooted at u; among ties the
// smallest index is printed.
//
// With cnt[u] = size of u's subtree (tree rooted at 1) and f[u] = sum of
// depths with root u:
//     f[child] = f[parent] + n - 2 * cnt[child]
// because the child's subtree moves one edge closer and the rest one farther.
//
// Implemented iteratively over a BFS order instead of by recursion: n can be
// 10^6, and a path-shaped tree would otherwise need 10^6 nested calls, which
// risks overflowing the call stack.
void solve() {
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        adj[i].clear();
    }
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // BFS from vertex 1: every parent appears before all of its children.
    vector<int> fa(n + 1, 0);
    vector<int> order;
    order.reserve(n);
    order.push_back(1);
    for (size_t head = 0; head < order.size(); head++) {
        int u = order[head];
        for (int w : adj[u]) {
            if (w != fa[u]) {
                fa[w] = u;
                order.push_back(w);
            }
        }
    }
    const int visited = static_cast<int>(order.size());

    // Reverse BFS order visits children before parents.
    vector<long long> cnt(n + 1, 1);
    vector<long long> depthSum(n + 1, 0);  // sum of depths inside u's subtree, measured from u
    for (int i = visited - 1; i > 0; i--) {
        int u = order[i];
        cnt[fa[u]] += cnt[u];
        depthSum[fa[u]] += depthSum[u] + cnt[u];
    }

    // Forward BFS order: each parent's value is final before its children's.
    vector<long long> f(n + 1, 0);
    f[1] = depthSum[1];
    for (int i = 1; i < visited; i++) {
        int u = order[i];
        f[u] = f[fa[u]] + n - 2 * cnt[u];
    }

    int best = 1;
    for (int u = 2; u <= n; u++) {
        if (f[u] > f[best]) {
            best = u;
        }
    }
    cout << best << '\n';
}
// Program entry: switch to unsynchronised stream I/O, then solve one test case.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    for (int remaining = 1; remaining != 0; --remaining) {
        solve();
    }
    return 0;
}
P3047 [USACO12FEB] Nearby Cows G
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 1e5 + 5;
// Adjacency lists of the tree (vertices 1..n), reset at the start of solve().
vector<int> adj[maxn];

// Luogu P3047 [USACO12FEB] Nearby Cows G.
//
// Reads n, k, the n-1 tree edges, then the weights of vertices 1..n, and prints
// for every vertex u (one per line) the total weight of vertices within
// distance k of u.
//
// After pass 1 plus prefix sums, dp[u][j] = total weight of vertices in u's
// subtree (tree rooted at 1) at depth <= j below u. For an ancestor a at
// distance d from u, with c the child of a on the path to u, the vertices
// newly reached through a are those within depth k - d of a, minus those
// within depth k - d - 1 of c (already counted through c). When d == k
// nothing is subtracted.
void solve() {
    int n, k;
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        adj[i].clear();
    }
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    vector<long long> weight(n + 1);
    for (int i = 1; i <= n; i++) {
        cin >> weight[i];
    }

    // Sized by k so any k works (a fixed [..][25] table overflows for k > 24).
    vector<vector<long long>> dp(n + 1, vector<long long>(k + 1, 0));

    // Pass 1: dp[u][j] = weight at depth exactly j below u.
    function<void(int, int)> dfs1 = [&](int u, int parent) {
        dp[u][0] = weight[u];
        for (int w : adj[u]) {
            if (w == parent) {
                continue;
            }
            dfs1(w, u);
            for (int j = 1; j <= k; j++) {
                dp[u][j] += dp[w][j - 1];
            }
        }
    };
    dfs1(1, 0);

    // Turn "exactly depth j" into "depth at most j".
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= k; j++) {
            dp[i][j] += dp[i][j - 1];
        }
    }

    vector<long long> ans(n + 1);
    vector<int> path;  // vertices on the root-to-current path
    function<void(int, int)> dfs2 = [&](int u, int parent) {
        path.push_back(u);
        const int last = static_cast<int>(path.size()) - 1;
        long long total = dp[u][k];
        for (int d = 1; d <= k && last - d >= 0; d++) {
            int a = path[last - d];      // ancestor at distance d
            int c = path[last - d + 1];  // its child on the path towards u
            total += dp[a][k - d];
            if (k - d - 1 >= 0) {
                total -= dp[c][k - d - 1];
            }
        }
        ans[u] = total;
        for (int w : adj[u]) {
            if (w != parent) {
                dfs2(w, u);
            }
        }
        path.pop_back();
    };
    dfs2(1, 0);

    for (int i = 1; i <= n; i++) {
        cout << ans[i] << '\n';
    }
}
// Program entry: switch to unsynchronised stream I/O, then solve one test case.
// (For a multi-case judge, read the case count from std::cin before the loop.)
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    for (int remaining = 1; remaining != 0; --remaining) {
        solve();
    }
    return 0;
}