SPOJ 1825 Free tour II 树分治
题意:
给出一棵边带权的树,树上的点有黑色和白色两种。求一条经过的黑色节点不超过k个的路径,使其长度(边权之和)最大,输出这个最大长度。
分析:
说一下题目的坑点:
- 递归函数的定义前面要加 `inline`,否则会 RE。不知道这是什么鬼,=_=|。(推测:树退化成一条链时递归深度可达 n,默认栈空间可能因此溢出;`inline` 或许让编译器把几层递归展开到一起,减少了栈的占用。)
- `ans` 要初始化为 0,而不是一个绝对值很大的负数,因为我们可以选择只有一个顶点的路径,这样权值就是 0。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <map>
#define MP make_pair
using namespace std;
// Integer pair; findCenter() ranks candidate centroids as (largest remaining piece, vertex).
typedef pair<int, int> PII;
// Capacity of every per-vertex / per-black-count array (presumably the 200000-vertex
// limit of the problem plus some slack).
static const int maxn = 200010;
// Large sentinel; -INF marks "no distance recorded yet" in f[] and g[].
static const int INF = 0x3f3f3f3f;
// One direction of an undirected weighted edge. The half-edges leaving a vertex
// form a singly linked list threaded through `nxt`, starting at head[vertex].
struct Edge
{
    int v;    // vertex this half-edge points to
    int w;    // weight of the edge
    int nxt;  // index of the next half-edge out of the same vertex; -1 ends the list
    Edge() {}
    Edge(int to, int weight, int next) : v(to), w(weight), nxt(next) {}
};
// Problem input and running answer.
int n;    // number of vertices
int k;    // maximum number of black vertices allowed on the path
int m;    // number of black vertices listed in the input (main counts it down while reading)
int ans;  // best path length found so far
// black[x] is 1 if vertex x is black, 0 otherwise.
int black[maxn];
// Adjacency lists: head[u] is the first half-edge index out of u (-1 if none),
// edges[] stores both directions of every input edge, ecnt is the next free slot.
// (The former unused global `vector<int> a` has been dropped.)
int ecnt, head[maxn];
Edge edges[maxn * 2];
inline void AddEdge(int u, int v, int w) {
edges[ecnt] = Edge(v, w, head[u]);
head[u] = ecnt++;
}
// Centroid search scratch: parent links and subtree sizes from dfs().
int fa[maxn];
int sz[maxn];
// dep[x]: most black vertices on a downward path starting at x (filled by getdep()).
int dep[maxn];
// Indexed by black count: g[c] = best distance with exactly c blacks in the current
// child, f[j] = best distance with at most j blacks over earlier children.
int f[maxn];
int g[maxn];
// num[1..tot]: indices of the edges from the current centroid to its live neighbours.
int num[maxn];
// del[x] is true while x is a centroid whose component is being processed.
bool del[maxn];
// Fills sz[] (subtree sizes) and fa[] (parent links) for every vertex that hangs
// below u in the current component, skipping deleted vertices. The caller must set
// fa[u] beforehand (solve() uses 0, which is not a vertex id) so the walk never
// steps back out of u.
//
// Written iteratively on purpose: a recursive walk nests once per tree level, and
// on a path-shaped tree with ~2*10^5 vertices that can overflow the call stack.
inline void dfs(int u) {
    // Vertices in BFS order: every parent appears before all of its children.
    static std::vector<int> order;
    order.clear();
    order.push_back(u);
    for (std::size_t idx = 0; idx < order.size(); ++idx) {
        const int x = order[idx];
        sz[x] = 1;
        for (int i = head[x]; ~i; i = edges[i].nxt) {
            const int v = edges[i].v;
            if (del[v] || v == fa[x]) continue;
            fa[v] = x;
            order.push_back(v);
        }
    }
    // Reverse BFS order completes each child's size before it is added to its parent.
    for (std::size_t idx = order.size() - 1; idx > 0; --idx) {
        const int x = order[idx];
        sz[fa[x]] += sz[x];
    }
}
// Returns (m, c): c is the vertex below u (inclusive) that minimises m, the size of
// the largest piece left when c is removed from a component of t vertices; ties go
// to the smaller vertex id (pair ordering). solve() passes t = sz[u]. Relies on the
// fa[]/sz[] values produced by the preceding dfs(u).
//
// Uses an explicit stack instead of recursion so that deep (path-like) trees cannot
// overflow the call stack. The result is a minimum over all visited vertices, so the
// visiting order does not affect it.
inline PII findCenter(int u, int t) {
    PII ans(INF, u);
    static std::vector<int> stk;
    stk.clear();
    stk.push_back(u);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        int m = 0;  // largest child piece of x
        for (int i = head[x]; ~i; i = edges[i].nxt) {
            const int v = edges[i].v;
            if (del[v] || v == fa[x]) continue;
            m = max(m, sz[v]);
            stk.push_back(v);
        }
        m = max(m, t - sz[x]);  // the piece containing x's parent side
        ans = min(ans, MP(m, x));
    }
    return ans;
}
// For every vertex x in the subtree entered from p through u (deleted vertices
// excluded), sets dep[x] to the largest number of black vertices on a path from x
// down into its subtree, counting x itself. dep[p] is not touched.
//
// Iterative (BFS + reverse pass) so that a very deep subtree cannot overflow the
// call stack the way a recursive walk could.
inline void getdep(int u, int p) {
    // order[i] is a vertex, par[i] its parent; parents precede children.
    static std::vector<int> order, par;
    order.clear();
    par.clear();
    order.push_back(u);
    par.push_back(p);
    for (std::size_t idx = 0; idx < order.size(); ++idx) {
        const int x = order[idx];
        const int px = par[idx];
        dep[x] = 0;
        for (int i = head[x]; ~i; i = edges[i].nxt) {
            const int v = edges[i].v;
            if (v == px || del[v]) continue;
            order.push_back(v);
            par.push_back(x);
        }
    }
    // In reverse order all children of x are final before x itself is finished.
    for (std::size_t idx = order.size(); idx-- > 0; ) {
        const int x = order[idx];
        dep[x] += black[x];
        if (idx > 0) dep[par[idx]] = max(dep[par[idx]], dep[x]);
    }
}
// Sort key for edge indices leaving the current centroid: ascending dep[] of the
// vertex at the far end of each edge.
bool cmp(int lhs_edge, int rhs_edge) {
    const int lhs = dep[edges[lhs_edge].v];
    const int rhs = dep[edges[rhs_edge].v];
    return lhs < rhs;
}
inline void getg(int u, int p, int d, int c) {
g[c] = max(g[c], d);
for(int i = head[u]; ~i; i = edges[i].nxt) {
Edge& e = edges[i];
if(e.v == p || del[e.v]) continue;
getg(e.v, u, d + e.w, c + black[e.v]);
}
}
// Centroid-decomposition driver for the component of non-deleted vertices that
// contains u; raises ans to the longest path in it with at most k black vertices.
//
// Finds the centroid s, solves every sub-component first (with s marked deleted),
// then combines paths through s. The children of s are processed in ascending
// dep[] order, so when child i is merged, f[0..dep of child i-1] holds, for each
// black budget j, the longest distance from s into an earlier child using at most
// j black vertices (s itself not counted). del[s] is cleared before returning, so
// on exit every vertex of the component is non-deleted again.
inline void solve(int u) {
    fa[u] = 0;  // 0 is not a vertex id, so dfs() never steps back out of u
    dfs(u);
    int s = findCenter(u, sz[u]).second;
    del[s] = true;
    // Paths avoiding s lie inside a single sub-component: handle those first.
    for (int i = head[s]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if (del[v]) continue;
        solve(v);
    }
    // Each recursive call restored its own del[] marks, so every live neighbour v
    // of s again sees its whole sub-component here.
    int tot = 0;
    for (int i = head[s]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if (del[v]) continue;
        getdep(v, s);
        num[++tot] = i;
    }
    if (tot == 0) {
        // No live neighbours: the only path through s is s alone, of length 0,
        // which cannot exceed ans (main starts it at 0). Returning here also
        // avoids reading the never-written num[0] below.
        del[s] = false;
        return;
    }
    sort(num + 1, num + 1 + tot, cmp);
    int mxdep = dep[edges[num[tot]].v];
    for (int i = 0; i <= mxdep; i++) f[i] = -INF;
    const int budget = k - black[s];  // black vertices allowed besides s itself
    for (int i = 1; i <= tot; i++) {
        int v = edges[num[i]].v, w = edges[num[i]].w;
        int d = dep[v];
        // g[c]: longest distance from s into this child with exactly c blacks.
        for (int j = 0; j <= d; j++) g[j] = -INF;
        getg(v, s, w, black[v]);
        if (i != 1) {
            // Pair this child (j blacks) with the best earlier child (<= j2 blacks).
            // f is non-decreasing and j2 never grows as j grows, so once f[j2] is
            // empty every later f[j2] is empty too.
            int prevdep = dep[edges[num[i - 1]].v];
            for (int j = 0; j <= budget && j <= d; j++) {
                int j2 = min(prevdep, budget - j);
                if (f[j2] == -INF) break;
                if (g[j] != -INF) ans = max(ans, g[j] + f[j2]);
            }
        }
        // Fold g into f as a prefix maximum; f[j] alone is a path ending at s.
        for (int j = 0; j <= d; j++) {
            f[j] = max(f[j], g[j]);
            if (j) f[j] = max(f[j], f[j - 1]);
            if (j <= budget) ans = max(ans, f[j]);
        }
    }
    del[s] = false;
}
// Reads "n k m", the m black vertex ids and the n-1 weighted edges, runs the
// centroid decomposition from vertex 1 and prints the best path length.
int main()
{
    ecnt = 0;
    memset(head, -1, sizeof(head));  // -1 terminates every adjacency list
    ans = 0;                          // a single vertex is a valid path of length 0
    scanf("%d%d%d", &n, &k, &m);
    while (m--) {
        int id;
        scanf("%d", &id);
        black[id] = 1;
    }
    for (int e = 0; e + 1 < n; ++e) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        AddEdge(u, v, w);
        AddEdge(v, u, w);
    }
    solve(1);
    printf("%d\n", ans);
    return 0;
}