bzoj 2599 [IOI2011]Race 点分治
题面
解法
记\(mn_i\)表示在当前分治中心下,已处理完的子树中从中心出发、长度恰好为\(i\)的路径最少经过多少条边
然后进行点分治
假设当前的分治中心(重心)为\(x\),把\(x\)的子树一棵一棵地处理:遍历一棵子树时,对其中每个点用\(mn\)(来自之前处理过的子树)更新答案;遍历完这棵子树后,再把它的路径信息加入\(mn\)
递归处理下面的子树之前,把这一层对\(mn\)的修改撤销(只重置被修改过的位置,避免每层都花\(O(K)\)清空)
时间复杂度:\(O(n \log n)\)(另有一次\(O(K)\)的\(mn\)初始化)
代码
#include <bits/stdc++.h>
// Sentinel meaning "no path found yet". It leaves headroom below INT_MAX so
// that `mn[...] + s[x]` in dfs() stays in range.
// Declared as a constant rather than an unparenthesized macro, so expressions
// such as `inf - 1` or `inf / 2` mean what they say.
constexpr int inf = 1 << 30;
// Array capacity: bound on vertex ids and on the path lengths mn[] can index.
constexpr int N = 1000010;
using namespace std;
// Raises x to y when y is larger; otherwise x is left as it is.
template <typename node> void chkmax(node &x, node y) {
    if (x < y) x = y;
}
// Lowers x to y when y is smaller; otherwise x is left as it is.
template <typename node> void chkmin(node &x, node y) {
    if (y < x) x = y;
}
// Reads the next integer from stdin into x.
// - Any non-digit prefix is skipped. A '-' anywhere in that prefix makes the
//   value negative.
// - The character that ends the digit run is consumed.
// - c is an int, not a char: getchar() returns EOF (a negative int) at end of
//   input, and passing a negative char to isdigit is undefined behaviour.
// - At end of input the loops stop instead of spinning forever. If no digits
//   were read, x is left as 0.
template <typename node> void read(node &x) {
    x = 0;
    int f = 1;
    int c = getchar();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= f;
}
// Adjacency lists packed into one array.
// e[1..n] are per-vertex list heads (only .next is used); edges added by
// add() are stored after them. A .next of 0 ends a list.
struct Edge {
    int next;  // index of the next entry in the same list
    int num;   // neighbour vertex
    int v;     // edge weight
} e[N * 3];
// n: vertex count; K: target path length; ans: fewest edges found so far.
// cnt: last used slot of e[]; rt: current centroid candidate.
// now: size of the component being searched for a centroid.
int n, K, ans, cnt, rt, now;
// vis[x]: x has already served as a centroid.
// siz[x]: subtree size from the last getr() walk.
// f[x]: largest remaining part if x were removed.
// d[x] / s[x]: distance / edge count from the current centroid.
// mn[len]: fewest edges of a recorded path of length len.
int vis[N], siz[N], f[N], d[N], s[N], mn[N];
// Appends a directed edge x -> y of weight v to x's adjacency list.
// e[x] acts as the list head; the new edge goes into slot cnt + 1.
// Uses standard aggregate initialization rather than the C99 compound literal
// `(Edge){...}`, which ISO C++ lacks (GCC/Clang accept it only as an extension).
void add(int x, int y, int v) {
    ++cnt;
    e[cnt] = Edge{e[x].next, y, v};
    e[x].next = cnt;
}
void getr(int x, int fa) {
siz[x] = 1, f[x] = 0;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num, v = e[p].v;
if (vis[k] || k == fa) continue;
getr(k, x); siz[x] += siz[k];
chkmax(f[x], siz[k]);
}
chkmax(f[x], now - siz[x]);
if (f[x] < f[rt]) rt = x;
}
// Query pass over one child subtree of the current centroid.
// d[x]: distance from the centroid to x; s[x]: number of edges on that path.
// Each vertex with d[x] <= K is combined with the cheapest recorded path of
// length K - d[x] in mn[]. Those entries come from update() on earlier
// subtrees; mn[0] = 0 stands for the centroid itself. ans keeps the smallest
// total edge count.
// Assumes non-negative edge weights: once d[x] > K every deeper vertex is at
// least as far, so the walk stops there. Without this cut-off the walk covers
// the whole subtree, and on a long path with large weights d[x] + v can
// overflow int. With it, every stored d[] is at most K plus one edge weight.
void dfs(int x, int fa) {
    if (d[x] > K) return;
    chkmin(ans, mn[K - d[x]] + s[x]);
    for (int p = e[x].next; p; p = e[p].next) {
        int k = e[p].num, v = e[p].v;
        if (k == fa || vis[k]) continue;
        d[k] = d[x] + v, s[k] = s[x] + 1;
        dfs(k, x);
    }
}
// Insert pass over the same subtree. For each reachable length d[x] <= K it
// records the fewest edges s[x] needed from the centroid.
// Must prune exactly where dfs() does: below a vertex with d[x] > K, dfs()
// never assigned d[] / s[], so those entries hold values from an older call.
void update(int x, int fa) {
    if (d[x] > K) return;
    chkmin(mn[d[x]], s[x]);
    for (int p = e[x].next; p; p = e[p].next) {
        int k = e[p].num;
        if (k == fa || vis[k]) continue;
        update(k, x);
    }
}
// Handles the child subtree rooted at x, joined to the current centroid by an
// edge of weight v. It queries against earlier subtrees first and only then
// inserts this one, so no path is paired with another from the same subtree.
// fa = 0 is passed because the centroid is already excluded via vis[].
void calc(int x, int v) {
    d[x] = v, s[x] = 1;
    dfs(x, 0);
    update(x, 0);
}
// Undoes the mn[] writes made for one child subtree of the current centroid by
// resetting mn[d[x]] to inf.
// The walk stops at the first vertex with d[x] > K. This assumes non-negative
// edge weights, so deeper vertices cannot be closer.
void Clear(int x, int fa) {
    if (d[x] > K) return;
    mn[d[x]] = inf;
    for (int p = e[x].next; p != 0; p = e[p].next) {
        const int to = e[p].num;
        if (to != fa && !vis[to]) Clear(to, x);
    }
}
// Processes the component whose centroid is x. It counts every path through x
// of total length K, then recurses into each component left after removing x,
// splitting it at that component's own centroid.
void work(int x) {
    vis[x] = 1;
    mn[0] = 0;  // x itself: a path of length 0 using 0 edges
    // Query-then-insert one child subtree at a time (see calc()).
    for (int p = e[x].next; p; p = e[p].next) {
        int k = e[p].num;
        if (!vis[k]) calc(k, e[p].v);
    }
    // Reset only the mn[] slots written above, avoiding an O(K) clear per centroid.
    for (int p = e[x].next; p; p = e[p].next) {
        int k = e[p].num;
        if (!vis[k]) Clear(k, x);
    }
    // siz[] was filled by the getr() walk that found x, which was rooted at a
    // different vertex. Refresh it rooted at x so siz[k] is the true size of
    // each remaining piece. rt and f[] are rewritten below, so the values
    // getr() leaves in them here do not matter.
    getr(x, 0);
    for (int p = e[x].next; p; p = e[p].next) {
        int k = e[p].num;
        if (vis[k]) continue;
        f[rt = 0] = inf, now = siz[k];
        getr(k, x);
        // Recurse on the centroid getr() found, not on k. Splitting at k
        // degrades to quadratic time (and very deep recursion) on path-like trees.
        work(rt);
    }
}
// Reads the tree (0-based vertex ids, n - 1 weighted edges) and runs the
// centroid decomposition. Prints the fewest edges on a path of total length K,
// or -1 if no such path exists.
int main() {
    read(n);
    read(K);
    cnt = n;  // e[1..n] are adjacency-list heads; edges are stored after them
    for (int i = 1; i < n; ++i) {
        int a, b, w;
        read(a), read(b), read(w);
        ++a, ++b;  // shift to 1-based ids
        add(a, b, w);
        add(b, a, w);
    }
    ans = inf;
    for (int len = 1; len <= K; ++len) mn[len] = inf;  // mn[0] keeps its initial 0
    now = n;
    f[rt = 0] = inf;
    getr(1, 0);
    work(rt);
    if (ans != inf) cout << ans << "\n";
    else cout << "-1\n";
    return 0;
}