poj 3728(LCA + dp)
题目链接:http://poj.org/problem?id=3728
思路:题目的意思是求树上从 a 走到 b 的路径上的最大收益:沿 a -> b 的方向先在某个点买入、再在其后的某个点卖出(尽量在低价买入、高价卖出),若无法获利则收益为 0。
我们假设路径 a -> b 的最近公共祖先 LCA(a, b) = f,并令 up[a] 表示 a -> f 路径上的最大收益,down[a] 表示 f -> a 路径上的最大收益,dp_max[a] 表示 a -> f 路径上的最大值,dp_min[a] 表示 a -> f 路径上的最小值。最优方案只有三种情况:买卖都在 a -> f 段完成、买卖都在 f -> b 段完成、或者在 a -> f 段以最低价买入并在 f -> b 段以最高价卖出,于是可以得出关系: ans[id] = max(max(up[a], down[b]), dp_max[b] - dp_min[a])。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

// POJ 3728: for each query (a, b), the maximum profit obtainable by buying at
// one node and selling at a later node while walking the tree path a -> b.
// Solved offline with Tarjan's LCA plus a weighted union-find whose nodes carry
// path summaries (see the comment above `parent`).

// Array capacity: N and Q up to 50000 plus slack (presumably the problem's
// limits -- confirm against the statement).
const int MAX_N = 50000 + 5000;
// Edge-list capacity: the tree uses 2 * (N - 1) slots, the query lists at most 2 * Q.
const int MAX_M = MAX_N << 2;

int NE1, NE2, NE3, head1[MAX_N], head2[MAX_N], head3[MAX_N];

// Clears the three adjacency lists (tree, queries, per-LCA answers).
void Init() {
    NE1 = NE2 = NE3 = 0;
    memset(head1, -1, sizeof(head1));
    memset(head2, -1, sizeof(head2));
    memset(head3, -1, sizeof(head3));
}

int N, Q, ans[MAX_N], value[MAX_N], vis[MAX_N];

// Tree edges; each undirected edge is stored once in each direction.
struct Edge1 {
    int v, next;
} edge1[MAX_M];

// Pushes tree edge u -> v onto the front of u's list.
void Insert1(int u, int v) {
    edge1[NE1].v = v;
    edge1[NE1].next = head1[u];
    head1[u] = NE1++;
}

// Query edges.
//   edge2/head2: every query is stored at both of its endpoints.
//   edge3/head3: every query is stored at its LCA once that is known.
struct Edge {
    int v, id, next;
} edge2[MAX_M], edge3[MAX_M];

// Pushes (u -> v, query id) onto u's query list when flag == 0,
// otherwise onto u's per-LCA answer list.
void Insert2(int u, int v, int id, int flag) {
    if (!flag) {
        edge2[NE2].v = v;
        edge2[NE2].id = id;
        edge2[NE2].next = head2[u];
        head2[u] = NE2++;
    } else {
        edge3[NE3].v = v;
        edge3[NE3].id = id;
        edge3[NE3].next = head3[u];
        head3[u] = NE3++;
    }
}

// Union-find parent. Right after find(x) returns root r, the arrays below
// summarize the tree path between x and r (both ends included):
//   up[x]      best profit buying then selling while walking x -> r
//   down[x]    best profit buying then selling while walking r -> x
//   dp_max[x]  largest value on that path
//   dp_min[x]  smallest value on that path
// A node that is still its own root keeps its initial values
// (up = down = 0, dp_max = dp_min = value).
int parent[MAX_N];
int up[MAX_N], down[MAX_N], dp_max[MAX_N], dp_min[MAX_N];

// Scratch buffers for the non-recursive find() and Tarjan().
int find_path[MAX_N];
int dfs_stack[MAX_N], dfs_edge[MAX_N];

// Returns the root of x's set. Compresses the path and, going from the node
// just below the root down to x, merges each node's summary with its
// (already updated) parent's summary.
// Iterative on purpose: the recursive form can nest O(N) deep on long chains.
int find(int x) {
    int len = 0;
    int root = x;
    while (parent[root] != root) {
        find_path[len++] = root;
        root = parent[root];
    }
    for (int k = len - 1; k >= 0; --k) {
        const int y = find_path[k];
        const int fa = parent[y];  // original parent; its summary already reaches root
        // Buy on y's lower segment, sell on fa's segment (toward the root).
        up[y] = std::max(std::max(up[y], up[fa]), dp_max[fa] - dp_min[y]);
        // Buy on fa's segment (from the root), sell on y's segment.
        down[y] = std::max(std::max(down[y], down[fa]), dp_max[y] - dp_min[fa]);
        dp_max[y] = std::max(dp_max[y], dp_max[fa]);
        dp_min[y] = std::min(dp_min[y], dp_min[fa]);
        parent[y] = root;
    }
    return root;
}

// Query endpoints indexed by query id (1..Q): travel from u to v.
struct Node {
    int u, v;
} node[MAX_N];

// DFS pre-visit of u. Marks u visited and makes it a singleton set. For every
// query whose other endpoint v was already visited, find(v) is their LCA, so the
// query is filed in that LCA's answer list.
static void EnterNode(int u) {
    vis[u] = 1;
    parent[u] = u;
    for (int i = head2[u]; ~i; i = edge2[i].next) {
        const int v = edge2[i].v;
        const int id = edge2[i].id;
        if (!vis[v]) continue;
        const int fa = find(v);
        Insert2(fa, v, id, 1);
    }
}

// DFS post-visit of u. Every finished child subtree hangs under u and u is
// still a root, so find() summarizes a -> u and u -> b for each query (a, b)
// whose LCA is u.
static void LeaveNode(int u) {
    for (int i = head3[u]; ~i; i = edge3[i].next) {
        const int id = edge3[i].id;
        const int a = node[id].u;
        const int b = node[id].v;
        find(a);
        find(b);
        // Trade entirely on a -> u, entirely on u -> b, or buy on the first
        // half and sell on the second.
        ans[id] = std::max(std::max(up[a], down[b]), dp_max[b] - dp_min[a]);
    }
}

// Offline Tarjan LCA over the tree rooted at `root`; fills ans[1..Q].
// Uses an explicit stack (same visit order as the recursive DFS) so that
// path-shaped trees cannot overflow the call stack.
void Tarjan(int root) {
    int top = 0;
    dfs_stack[top] = root;
    dfs_edge[top] = head1[root];
    EnterNode(root);
    while (top >= 0) {
        const int i = dfs_edge[top];
        if (~i) {
            dfs_edge[top] = edge1[i].next;  // advance before descending
            const int v = edge1[i].v;
            if (vis[v]) continue;
            ++top;
            dfs_stack[top] = v;
            dfs_edge[top] = head1[v];
            EnterNode(v);
        } else {
            const int u = dfs_stack[top];
            LeaveNode(u);
            --top;
            // Union the finished subtree into its DFS parent.
            if (top >= 0) parent[u] = dfs_stack[top];
        }
    }
}

// Reads test cases until input runs out (or a non-integer token appears):
// N node values, N - 1 tree edges, then Q queries "a b"; prints one answer per query.
int main() {
    while (scanf("%d", &N) == 1) {
        for (int i = 1; i <= N; ++i) {
            scanf("%d", &value[i]);
            up[i] = down[i] = 0;
            dp_max[i] = dp_min[i] = value[i];
        }
        Init();
        for (int i = 1; i < N; ++i) {
            int u, v;
            scanf("%d %d", &u, &v);
            Insert1(u, v);
            Insert1(v, u);
        }
        scanf("%d", &Q);
        for (int i = 1; i <= Q; ++i) {
            scanf("%d %d", &node[i].u, &node[i].v);
            Insert2(node[i].u, node[i].v, i, 0);
            Insert2(node[i].v, node[i].u, i, 0);
        }
        memset(vis, 0, sizeof(vis));
        Tarjan(1);
        for (int i = 1; i <= Q; ++i) printf("%d\n", ans[i]);
    }
    return 0;
}