zoj 3820(2014牡丹江现场赛B题)
题目链接:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=5374
思路:题目的意思是在树上选两个点,使得树上每个点到离它较近的那个点的距离的最大值最小。可以想到这题与树的直径有关,我们可以这样做:首先求出树的直径,然后取出直径的中点,以及直径上与该中点相邻的一个点,删掉这两点之间的边,就把这棵树划分为两棵子树;然后分别求出这两棵子树的直径,最后要选择的两个点就分别是这两棵子树直径上的中点。
一开始是用dfs写的,结果爆栈了,改成bfs就过了。
// ZOJ 3820: choose two distinct nodes of a tree so that the largest distance
// from any node to the nearer chosen node is as small as possible; print that
// distance followed by the two nodes.
//
// Approach: find the tree's diameter, cut the edge in the middle of it, and
// pick the middle of the diameter of each of the two resulting components.
// All traversals are iterative BFS (a recursive DFS overflowed the stack).
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;

const int MAX_N = (200000 + 20000);

// Adjacency lists stored as linked lists inside one edge array:
// head[u] is the first edge leaving u, edge[i].next the following one (-1 ends).
struct Edge {
    int v, w, next;
} edge[MAX_N << 1];
int N, NE, head[MAX_N];

// Resets the adjacency lists of nodes 1..N (N must already be read).
// Only the slots used by the current test case are cleared, so the cost is
// O(N) per test case rather than O(MAX_N).
void Init()
{
    NE = 0;
    fill(head, head + N + 1, -1);
}

// Appends the directed edge u -> v with weight w.
void Insert(int u, int v, int w)
{
    edge[NE].v = v;
    edge[NE].w = w;
    edge[NE].next = head[u];
    head[u] = NE++;
}

// dep[x]      : distance from the root of the last bfs() to x
// path[x]     : predecessor of x in the last bfs() (the root gets `fa`)
// vis[x]      : whether the last bfs() reached x
// s_mid/e_mid : endpoints of the cut edge (-1 while no edge is cut)
int dep[MAX_N], path[MAX_N], st, ed, s_mid, e_mid;
int ans_minDist, ans_point1, ans_point2;
bool vis[MAX_N];

// True iff u-v is the cut edge s_mid-e_mid, in either direction.
bool check(int u, int v)
{
    return (u == s_mid && v == e_mid) || (u == e_mid && v == s_mid);
}

// Breadth-first search from u. Fills dep[], path[] and vis[] for every node
// reachable from u without crossing the cut edge or stepping onto node `fa`.
// `deep` is the depth assigned to u and `fa` is recorded as u's predecessor.
// vis[] is reset here, so vis[] afterwards marks exactly u's component.
void bfs(int u, int fa, int deep)
{
    fill(vis, vis + N + 1, false);
    dep[u] = deep;
    path[u] = fa;
    vis[u] = true;
    queue<int> que;
    que.push(u);
    while (!que.empty()) {
        int cur = que.front();
        que.pop();
        for (int i = head[cur]; ~i; i = edge[i].next) {
            int v = edge[i].v;
            if (v == fa || vis[v] || check(cur, v)) continue;
            dep[v] = dep[cur] + edge[i].w;
            path[v] = cur;
            vis[v] = true;
            que.push(v);
        }
    }
}

// Returns the node reached by the last bfs() with the largest dep[];
// ties go to the smallest node index.
static int farthestVisited()
{
    int best = -1, bestDep = -1;
    for (int i = 1; i <= N; ++i) {
        if (vis[i] && dep[i] > bestDep) {
            bestDep = dep[i];
            best = i;
        }
    }
    return best;
}

// Follows path[] predecessor links `steps` times starting from `node`.
static int walkUp(int node, int steps)
{
    while (steps-- > 0) node = path[node];
    return node;
}

// Finds the whole tree's diameter st..ed and chooses the edge to cut.
// Along the diameter p0 = ed, p1 = path[ed], ..., pD = st, the cut edge is
// (p_{D/2}, p_{D/2+1}): s_mid = p_{D/2}, e_mid = p_{D/2+1}.
void gao()
{
    s_mid = e_mid = -1;  // no edge is cut while measuring the whole tree
    bfs(1, -1, 0);
    st = farthestVisited();
    bfs(st, -1, 0);
    ed = farthestVisited();
    // Walking an explicit number of steps also covers D/2 == 0.
    s_mid = walkUp(ed, dep[ed] / 2);
    e_mid = path[s_mid];
}

// In the component containing `root` (entered without stepping onto
// `blocked`, and with the cut edge removed), computes the diameter with two
// BFS passes and returns the node dep/2 steps from one diameter endpoint.
// `radius` receives the largest distance from that node to any node of the
// component.
static int componentCenter(int root, int blocked, int &radius)
{
    bfs(root, blocked, 0);
    int a = farthestVisited();
    bfs(a, -1, 0);
    int b = farthestVisited();
    int center = walkUp(b, dep[b] / 2);
    bfs(center, -1, 0);
    radius = dep[farthestVisited()];
    return center;
}

// Picks one station in each half of the cut tree; the answer distance is the
// larger of the two halves' radii.
void solve()
{
    int radius1 = 0, radius2 = 0;
    ans_point1 = componentCenter(s_mid, e_mid, radius1);
    ans_point2 = componentCenter(e_mid, s_mid, radius2);
    ans_minDist = max(radius1, radius2);
}

int main()
{
    int Cas;
    if (scanf("%d", &Cas) != 1) return 0;
    while (Cas--) {
        if (scanf("%d", &N) != 1) return 0;
        Init();
        for (int i = 1; i < N; ++i) {
            int u, v;
            if (scanf("%d %d", &u, &v) != 2) return 0;
            Insert(u, v, 1);
            Insert(v, u, 1);
        }
        // Two nodes: put a station on each, so every distance is 0.
        if (N == 2) {
            puts("0 1 2");
            continue;
        }
        // NOTE(review): assumes N >= 2 per the problem constraints — confirm.
        // With N == 1 gao() would leave e_mid == -1 and solve() would index
        // the arrays with -1.
        gao();
        solve();
        printf("%d %d %d\n", ans_minDist, ans_point1, ans_point2);
    }
    return 0;
}