HDU 5293 Tree chain problem 树形dp+dfs序+树状数组+LCA
题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=5293
题意:
给你一些链,每条链都有自己的价值,求不相交不重合的链能够组成的最大价值。
题解:
树形dp,
对于每条链u,v,w,我们只在lca(u,v)的顶点上处理它
让dp[i]表示以i为根的子树的最大值,sum[i]表示dp[vi]的和(vi为i的儿子们)
则i点有两种决策,一种是不选以i为lca的链,则dp[i]=sum[i]。
另一种是选一条以i为lca的链,那么有转移方程:dp[i]=sigma(dp[vj])+sigma(sum[kj])+w。(sigma表示累加,vj表示那些不在链上的孩子们,kj表示在链上的孩子们)
为了便于计算,我们处理出dp[i]=sum[i]-sigma(dp[k]-sum[k])+w=sum[i]+sigma(sum[k]-dp[k])+w。
利用dfs序和树状数组可以logn算出sigma(sum[k]-dp[k])。
#pragma comment(linker,"/STACK:1024000000,1024000000") #include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #include<vector> using namespace std; const int maxn = 202020; const int maxb= 22; struct Node { int u, v, w; Node(int u, int v, int w) :u(u), v(v), w(w) {} }; int n, m; vector<Node> que[maxn]; vector<int> G[maxn]; int lca[maxn][maxb],in[maxn],out[maxn],dep[maxn],dfs_cnt; int sumv[maxn]; int dp[maxn],sum[maxn]; //计算dfs序,in,out;预处理每个订点的祖先lca[i][j],表示i上面第2^j个祖先,lca[i][0]表示父亲 void dfs(int u, int fa,int d) { in[u] = ++dfs_cnt; lca[u][0] = fa; dep[u] = d; for (int j = 1; j < maxb; j++) { int f = lca[u][j - 1]; lca[u][j] = lca[f][j - 1]; } for (int i = 0; i < G[u].size(); i++) { int v = G[u][i]; if (v == fa) continue; dfs(v, u, d + 1); } out[u] = ++dfs_cnt; } //在线lca,o(n*logn)预处理+o(logn)询问 int Lca(int u, int v) { if (dep[u] < dep[v]) swap(u, v); //二进制倍增法,u,v提到相同高度 for (int i = maxb - 1; i >= 0; i--) { if (dep[lca[u][i]] >= dep[v]) u = lca[u][i]; } //当lca为u或者为v的时候 if (u == v) return u; //lca不是u也不是v的情况 //一起往上提 for (int i = maxb - 1; i >= 0; i--) { if (lca[u][i] != lca[v][i]) { u = lca[u][i]; v = lca[v][i]; } } return lca[u][0]; } //树状数组 int get_sum(int x) { int ret = 0; while (x > 0) { ret += sumv[x]; x -= x&(-x); } return ret; } void add(int x, int v) { while (x <maxn) { sumv[x] += v; x += x&(-x); } } //树形dp(用到dfs序和树状数组来快速计算链) //dfs序+树状数组的想法可以自己在纸上画画图, void solve(int u,int fa) { for (int i = 0; i < G[u].size(); i++) { int v = G[u][i]; if (v == fa) continue; solve(v, u); sum[u] += dp[v]; } dp[u] = sum[u]; for (int i = 0; i < que[u].size(); i++) { Node& nd = que[u][i]; //get_sum(in[nd.u])处理的是lca(u,v)到u点这条路径的所有顶点 //get_sum(out[nd.v])处理的是lca(u,v)到v点这条路径的所有顶点 dp[u] = max(dp[u], sum[u] + get_sum(in[nd.u]) + get_sum(in[nd.v]) + nd.w); } add(in[u], sum[u] - dp[u]); add(out[u], dp[u] - sum[u]); } void init() { dfs_cnt = 0; for (int i = 1; i <= n; i++) G[i].clear(); for (int i = 1; i <= n; i++) que[i].clear(); memset(lca, 0, sizeof(lca)); memset(sumv, 0, sizeof(sumv)); memset(sum, 0, sizeof(sum)); memset(dp, 0, sizeof(dp)); } int main() { int tc; scanf("%d", &tc); while (tc--) { scanf("%d%d", &n, &m); init(); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0,1); while (m--) { int u, v, w; scanf("%d%d%d", &u, &v, &w); //每条链在Lca的位置上处理,这样符合dp的无后效性 que[Lca(u, v)].push_back(Node(u, v, w)); } solve(1, 0); printf("%d\n", dp[1]); } return 0; } /* 1 7 111 1 2 1 3 2 4 2 5 3 6 3 7 3 3 1 2 1 3 1 1 1 2 2 2 3 3 3 3 3 1 2 1 3 1 1 1 1 2 3 3 3 1 */