SPOJ_1825
这是一道树分治的题目，详细的算法可以参考 2009 年漆子超的论文《分治算法在树的路径问题中的应用》。
/*
 * SPOJ 1825 -- longest path in a weighted tree that visits at most K marked
 * ("crowded") vertices, solved with centroid decomposition.
 *
 * Input (repeated until EOF): N K M, then M marked vertex ids, then N - 1
 * edges "x y z" (endpoints and weight).  For every test case the program
 * prints the maximum total weight of a simple path containing at most K
 * marked vertices; the empty path (weight 0) is always allowed.
 */
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <vector>

#define MAXD 200010     /* capacity of the per-vertex arrays            */
#define MAXM 400010     /* capacity of the directed-edge arrays         */
#define INF 0x3f3f3f3f

/* Graph as forward-star adjacency lists: first[x] is the head edge of x,
 * next[] chains the edges, v[] / w[] hold target vertex and weight.
 * col[x] is 1 when vertex x is marked, 0 otherwise. */
int N, K, M, first[MAXD], e, next[MAXM], v[MAXM], w[MAXM], col[MAXD];

/* Scratch space shared by the BFS passes:
 *   q    - BFS queue (vertex ids)
 *   size - subtree sizes computed by findroot()
 *   pre  - BFS parent of a vertex
 *   del  - 1 while a vertex is removed as a centroid of an active level
 *   num  - marked vertices on the path to queue slot i   (refresh only)
 *   sum  - accumulated weight of the path to queue slot i (refresh only)
 * ANS is the best path weight found so far. */
int q[MAXD], size[MAXD], ANS, pre[MAXD], del[MAXD], num[MAXD], sum[MAXD];

/* f[c][k]: best weight of a path that enters the component through its
 * entry vertex c (incoming edge weight included) and contains at most k
 * marked vertices; never below 0 because the path may be left empty. */
std::vector<int> f[MAXD];

/* Sort key used to merge child tables from the shortest to the longest. */
struct St
{
    int id, size;
    St() {}
    St(int _id, int _size) : id(_id), size(_size) {}
    bool operator < (const St &t) const
    {
        return size < t.size;
    }
};

/* Appends the directed edge x -> y with weight z to x's adjacency list. */
void add(int x, int y, int z)
{
    v[e] = y, w[e] = z;
    next[e] = first[x], first[x] = e ++;
}

/* Reads the marked vertices and the N - 1 edges of the current test case
 * from stdin and rebuilds the adjacency lists.  Reading stops quietly on
 * truncated input, and out-of-range vertex ids are ignored instead of
 * being used as array indices. */
void init()
{
    int i, x = 0, y = 0, z = 0;
    memset(col, 0, sizeof(col[0]) * (N + 1));
    memset(first, -1, sizeof(first[0]) * (N + 1));
    e = 0;
    for(i = 0; i < M; i ++)
    {
        if(scanf("%d", &x) != 1)
            return;
        if(x >= 1 && x <= N)
            col[x] = 1;
    }
    for(i = 1; i < N; i ++)
    {
        if(scanf("%d%d%d", &x, &y, &z) != 3)
            return;
        if(x < 1 || x > N || y < 1 || y > N)
            continue;
        add(x, y, z), add(y, x, z);
    }
}

/* Returns the centroid of the not-yet-deleted component containing cur:
 * the vertex whose removal leaves the smallest possible largest part. */
int findroot(int cur)
{
    int i, j, max = INF, rear = 0, root = cur;
    q[rear ++] = cur, pre[cur] = -1;
    /* BFS to obtain a parent-before-child ordering of the component. */
    for(i = 0; i < rear; i ++)
    {
        int x = q[i];
        for(j = first[x]; j != -1; j = next[j])
            if(!del[v[j]] && v[j] != pre[x])
                pre[v[j]] = x, q[rear ++] = v[j];
    }
    /* Walk the ordering backwards so children are sized before parents. */
    for(i = rear - 1; i >= 0; i --)
    {
        int x = q[i], m = 0;
        size[x] = 1;
        for(j = first[x]; j != -1; j = next[j])
            if(!del[v[j]] && v[j] != pre[x])
                size[x] += size[v[j]], m = std::max(m, size[v[j]]);
        m = std::max(m, rear - size[x]);    /* the part "above" x */
        if(m < max)
            root = x, max = m;
    }
    return root;
}

/* Rebuilds f[cur] for the whole not-yet-deleted component of cur, where d
 * is the weight of the edge through which the component is entered. */
void refresh(int cur, int d)
{
    int i, j, rear = 0;
    std::vector<int> &best = f[cur];
    /* Start from a clean table.  The same vertex can be the entry vertex
     * of several nested dfs() levels (with different d) and of several
     * test cases; keeping old entries would mix in path weights that do
     * not exist in the current component. */
    best.assign(1, 0);
    pre[cur] = -1, q[rear] = cur, num[rear] = col[cur], sum[rear] = d, ++ rear;
    for(i = 0; i < rear; i ++)
    {
        if(num[i] > K)
            continue;   /* too many marked vertices: prune this branch */
        int x = q[i];
        /* num grows by at most one per BFS step, so an index that is not
         * yet present is always exactly best.size(). */
        if(num[i] >= (int)best.size())
            best.push_back(sum[i]);
        else
            best[num[i]] = std::max(best[num[i]], sum[i]);
        for(j = first[x]; j != -1; j = next[j])
            if(!del[v[j]] && v[j] != pre[x])
                pre[v[j]] = x, q[rear] = v[j], num[rear] = num[i] + col[v[j]], sum[rear] = sum[i] + w[j], ++ rear;
    }
    /* Turn "exactly k marked vertices" into "at most k". */
    for(i = 1; i < (int)best.size(); i ++)
        best[i] = std::max(best[i], best[i - 1]);
}

/* Solves the component containing cur: picks its centroid, recurses into
 * the parts around it, combines every pair of parts through the centroid
 * into ANS, and finally publishes f[cur] for the caller.  d is the weight
 * of the edge that leads from the caller's centroid to cur. */
void dfs(int cur, int d)
{
    int i, j, root = findroot(cur);
    del[root] = 1;
    std::vector<St> st;
    for(i = first[root]; i != -1; i = next[i])
        if(!del[v[i]])
        {
            dfs(v[i], w[i]);
            st.push_back(St(v[i], (int)f[v[i]].size()));
        }
    /* Shortest tables first: the running prefix folded into f[x] below is
     * then never longer than the table it is folded into. */
    std::sort(st.begin(), st.end());
    for(i = 0; i < (int)st.size(); i ++)
    {
        int x = st[i].id;
        for(j = 0; j < (int)f[x].size(); j ++)
        {
            /* Path that ends at the centroid itself. */
            if(j + col[root] <= K)
                ANS = std::max(ANS, f[x][j]);
            if(i)
            {
                /* f[y] already holds the best of all earlier parts. */
                int y = st[i - 1].id, t = std::min(K - j - col[root], (int)f[y].size() - 1);
                if(t >= 0)
                    ANS = std::max(ANS, f[x][j] + f[y][t]);
                t = std::min((int)f[y].size() - 1, j);
                f[x][j] = std::max(f[x][j], f[y][t]);
            }
        }
    }
    del[root] = 0;
    refresh(cur, d);
}

/* Runs the decomposition on the tree currently stored in the globals,
 * leaves the result in ANS and prints it followed by a newline. */
void solve()
{
    ANS = 0;
    memset(del, 0, sizeof(del[0]) * (N + 1));
    dfs(1, 0);
    printf("%d\n", ANS);
}

/* Define SPOJ1825_NO_MAIN to link the solver into a test harness that
 * supplies its own entry point. */
#ifndef SPOJ1825_NO_MAIN
int main()
{
    while(scanf("%d%d%d", &N, &K, &M) == 3)
    {
        /* The static arrays cannot hold more than MAXD - 1 vertices. */
        if(N < 1 || N >= MAXD)
            break;
        init();
        solve();
    }
    return 0;
}
#endif