倍增/线段树维护树的直径 hdu5993/2016icpc青岛L
题意:
给一棵树,每次询问删掉两条边,问剩下的三棵树的最大直径
点10W,询问10W,询问相互独立
Solution:
考虑线段树/倍增维护树的直径
考虑一个点集的区间 [l, r]
而我们知道了有 l <= k < r,
且知道 [l, k] 和 [k + 1, r] 两个区间的最长链的端点及长度
假设两个区间的直径端点分别为 (l1, r1) 和 (l2, r2)
那么 [l, r] 这个区间的直径长度为
dis(l1, r1) dis(l1, l1) dis(l1, r2)
dis(r1, l2) dis(r1, r2) dis(l2, r2)
六个值中的最大值
本题因为操作子树,所以我们维护dfs序的区间最长链即可
证明:
首先有一个结论:
树上任意一个点在树中的最远点是树的直径的某个端点。我们可以用反证法轻易地证明这一点。
再扩展一下,有以下结论:树上任意一个点在树中的一个点集中的最远点是该点集中最长链的一个端点。
其实我们把点集等价地看为一棵虚树,然后就能用相似的证法解决了。
代码:

1 #include <stdio.h> 2 #include <algorithm> 3 4 using namespace std; 5 6 const int N = 2e5 + 5; 7 8 int T, n, m; 9 10 int len, head[N], ST[20][N]; 11 12 struct edge{int u, v, w;}ee[N]; 13 14 int cnt, fa[N], log_2[N], st[N], en[N], dfn[N], dis[N], dep[N], pos[N]; 15 16 struct edges{int to, next, cost;}e[N]; 17 18 inline void add(int u, int v, int w) { 19 e[++ len] = (edges){v, head[u], w}, head[u] = len; 20 e[++ len] = (edges){u, head[v], w}, head[v] = len; 21 } 22 23 inline void dfs1(int u) { 24 st[u] = ++ cnt, dfn[cnt] = u; 25 for (int v, i = head[u]; i; i = e[i].next) { 26 v = e[i].to; 27 if (v == fa[u]) continue; 28 fa[v] = u, dep[v] = dep[u] + 1; 29 dis[v] = dis[u] + e[i].cost, dfs1(v); 30 } 31 en[u] = cnt; 32 } 33 34 inline void dfs2(int u) { 35 dfn[++ cnt] = u, pos[u] = cnt; 36 for (int v, i = head[u]; i; i = e[i].next) { 37 v = e[i].to; 38 if (v == fa[u]) continue; 39 dfs2(v), dfn[++ cnt] = u; 40 } 41 } 42 43 int mmin(int x, int y) { 44 if (dep[x] < dep[y]) return x; 45 return y; 46 } 47 48 inline int lca(int u, int v) { 49 static int w; 50 if (pos[u] > pos[v]) swap(u, v); 51 w = log_2[pos[v] - pos[u] + 1]; 52 return mmin(ST[w][pos[u]], ST[w][pos[v] - (1 << w) + 1]); 53 } 54 55 inline int dist(int u, int v) { 56 int Lca = lca(u, v); 57 return dis[u] + dis[v] - dis[Lca] * 2; 58 } 59 60 inline void build() { 61 for (int i = 1; i <= cnt; i ++) 62 ST[0][i] = dfn[i]; 63 for (int i = 1; i < 20; i ++) 64 for (int j = 1; j <= cnt; j ++) 65 if (j + (1 << (i - 1)) > cnt) ST[i][j] = ST[i - 1][j]; 66 else ST[i][j] = mmin(ST[i - 1][j], ST[i - 1][j + (1 << (i - 1))]); 67 } 68 69 int M; 70 71 struct node { 72 int l, r, dis; 73 }tr[N << 1]; 74 75 inline void update(int o, int o1, int o2) { 76 static int d; 77 static node tmp; 78 if (tr[o1].dis == -1) {tr[o] = tr[o2]; return;} 79 if (tr[o2].dis == -1) {tr[o] = tr[o1]; return;} 80 if (tr[o1].dis > tr[o2].dis) tmp = tr[o1]; 81 else tmp = tr[o2]; 82 d = dist(tr[o1].l, tr[o2].l); 83 if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].l, tmp.dis = d; 84 d = dist(tr[o1].l, tr[o2].r); 85 if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].r, tmp.dis = d; 86 d = dist(tr[o1].r, tr[o2].l); 87 if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].l, tmp.dis = d; 88 d = dist(tr[o1].r, tr[o2].r); 89 if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].r, tmp.dis = d; 90 tr[o] = tmp; 91 } 92 93 inline void ask(int s, int t) { 94 if (s > t) return; 95 for (s += M - 1, t += M + 1; s ^ t ^ 1; s >>= 1, t >>= 1) { 96 if (~s&1) update(0, 0, s ^ 1); 97 if ( t&1) update(0, 0, t ^ 1); 98 } 99 } 100 101 inline int get_char() { 102 static const int SIZE = 1 << 23; 103 static char *T, *S = T, buf[SIZE]; 104 if (S == T) { 105 T = fread(buf, 1, SIZE, stdin) + (S = buf); 106 if (S == T) return -1; 107 } 108 return *S ++; 109 } 110 111 inline void in(int &x) { 112 static int ch; 113 while (ch = get_char(), ch > 57 || ch < 48);x = ch - 48; 114 while (ch = get_char(), ch > 47 && ch < 58) x = x * 10 + ch - 48; 115 } 116 117 int main() { 118 int u, v, w, ans; 119 log_2[1] = 0; 120 for (int i = 2; i <= 200000; i ++) 121 if (i == 1 << (log_2[i - 1] + 1)) 122 log_2[i] = log_2[i - 1] + 1; 123 else log_2[i] = log_2[i - 1]; 124 for (in(T); T --; ) { 125 in(n), in(m), cnt = len = 0; 126 for (int i = 1; i <= n; i ++) 127 head[i] = 0; 128 for (int i = 1; i < n; i ++) { 129 in(ee[i].u), in(ee[i].v), in(ee[i].w); 130 add(ee[i].u, ee[i].v, ee[i].w); 131 } 132 dfs1(1); 133 for (M = 1; M < n + 2; M <<= 1); 134 for (int i = 1; i <= n; i ++) 135 tr[i + M].l = tr[i + M].r = dfn[i], tr[i + M].dis = 0; 136 for (int i = n + M + 1; i <= (M << 1) + 1; i ++) 137 tr[i].dis = -1; 138 cnt = 0, dfs2(1), build(); 139 for (int i = M; i; i --) 140 update(i, i << 1, i << 1 | 1); 141 for (int i = 1; i < n; i ++) 142 if (dep[ee[i].u] > dep[ee[i].v]) 143 swap(ee[i].u, ee[i].v); 144 for (int u, v, i = 1; i <= m; i ++) { 145 in(u), in(v), ans = 0; 146 u = ee[u].v, v = ee[v].v, w = lca(u, v); 147 if (w == u || w == v) { 148 if (w != u) swap(u, v); 149 tr[0].dis = -1, ask(1, st[u] - 1), ask(en[u] + 1, n), ans = max(ans, tr[0].dis); 150 tr[0].dis = -1, ask(st[u], st[v] - 1), ask(en[v] + 1, en[u]), ans = max(ans, tr[0].dis); 151 tr[0].dis = -1, ask(st[v], en[v]), ans = max(ans, tr[0].dis); 152 } 153 else { 154 if (st[u] > st[v]) swap(u, v); 155 tr[0].dis = -1, ask(1, st[u] - 1), ask(en[u] + 1, st[v] - 1), ask(en[v] + 1, n), ans = max(ans, tr[0].dis); 156 tr[0].dis = -1, ask(st[u], en[u]), ans = max(ans, tr[0].dis); 157 tr[0].dis = -1, ask(st[v], en[v]), ans = max(ans, tr[0].dis); 158 } 159 printf("%d\n", ans); 160 } 161 } 162 return 0; 163 }
一开始没带脑子算错了复杂度,少算了个log开心的写了树剖LCA,还在dfs的时候求siz忘记把儿子的siz加上了
T到死...发现是带2个log,该死出题人多组数据不给数据组数,改写ST表O(1)求LCA,复杂度只带1个log过了
理论上线段树也可以用ST表代替,复杂度O(n)...当然不可能啦,预处理nlogn,回答O(1)
附加训练 51nod 1766

1 #include <stdio.h> 2 #include <algorithm> 3 4 using namespace std; 5 6 const int N = 1e5 + 5; 7 8 int n, m, M, tot, head[N], st[18][N << 1], log_2[N << 1]; 9 10 int cnt, dis[N], dep[N], pos[N], dfn[N << 1]; 11 12 struct edge{int to, next, cost;}e[N << 1]; 13 14 int mmin(int x, int y) { 15 return dep[x] < dep[y] ? x : y; 16 } 17 18 void add(int u, int v, int w) { 19 e[++ tot] = (edge){v, head[u], w}, head[u] = tot; 20 e[++ tot] = (edge){u, head[v], w}, head[v] = tot; 21 } 22 23 void dfs(int u, int fr) { 24 dfn[++ cnt] = u, pos[u] = cnt; 25 for (int v, i = head[u]; i; i = e[i].next) { 26 v = e[i].to; 27 if (v == fr) continue; 28 dep[v] = dep[u] + 1, dis[v] = dis[u] + e[i].cost; 29 dfs(v, u), dfn[++ cnt] = u; 30 } 31 } 32 33 int lca(int u, int v) { 34 if (pos[u] > pos[v]) swap(u, v); 35 int w = log_2[pos[v] - pos[u] + 1]; 36 return mmin(st[w][pos[u]], st[w][pos[v] - (1 << w) + 1]); 37 } 38 39 int dist(int u, int v) { 40 return dis[u] + dis[v] - dis[lca(u, v)] * 2; 41 } 42 43 struct node { 44 int l, r, dis; 45 46 node operator + (const node &a) const { 47 node res; 48 if (dis == -1) return a; 49 if (a.dis == -1) return *this; 50 if (dis > a.dis) res = *this; 51 else res = a; 52 int d = dist(l, a.l); 53 if (d > res.dis) res.l = l, res.r = a.l, res.dis = d; 54 d = dist(l, a.r); 55 if (d > res.dis) res.l = l, res.r = a.r, res.dis = d; 56 d = dist(r, a.l); 57 if (d > res.dis) res.l = r, res.r = a.l, res.dis = d; 58 d = dist(r, a.r); 59 if (d > res.dis) res.l = r, res.r = a.r, res.dis = d; 60 return res; 61 } 62 63 node operator * (const node &a) const { 64 node res; res.dis = -1; 65 int d = dist(l, a.l); 66 if (d > res.dis) res.l = l, res.r = a.l, res.dis = d; 67 d = dist(l, a.r); 68 if (d > res.dis) res.l = l, res.r = a.r, res.dis = d; 69 d = dist(r, a.l); 70 if (d > res.dis) res.l = r, res.r = a.l, res.dis = d; 71 d = dist(r, a.r); 72 if (d > res.dis) res.l = r, res.r = a.r, res.dis = d; 73 return res; 74 } 75 }tr[N << 2]; 76 77 node ask(int s, int t) { 78 node res; res.dis = -1; 79 for (s += M - 1, t += M + 1; s ^ t ^ 1; s >>= 1, t >>= 1) { 80 if (~s&1) res = res + tr[s ^ 1]; 81 if ( t&1) res = res + tr[t ^ 1]; 82 } 83 return res; 84 } 85 86 int main() { 87 scanf("%d", &n); 88 for (int u, v, w, i = 1; i < n; i ++) 89 scanf("%d %d %d", &u, &v, &w), add(u, v, w); 90 dfs(1, 1); 91 92 for (int i = 1; i <= cnt; i ++) 93 st[0][i] = dfn[i]; 94 for (int i = 1; i < 18; i ++) 95 for (int j = 1; j <= cnt; j ++) 96 if (j + (1 << (i - 1)) > cnt) st[i][j] = st[i - 1][j]; 97 else st[i][j] = mmin(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]); 98 log_2[1] = 0; 99 for (int i = 2; i <= cnt; i ++) 100 log_2[i] = log_2[i - 1] + (i == (1 << (log_2[i - 1] + 1))); 101 102 for (M = 1; M < n + 2; M <<= 1); 103 for (int i = 1; i <= n; i ++) tr[i + M] = (node){i, i, 0}; 104 for (int i = n + 1; i <= M + 1; i ++) tr[i + M].dis = -1; 105 for (int i = M; i; i --) tr[i] = tr[i << 1] + tr[i << 1 | 1]; 106 107 node tmp; int a, b, c, d; 108 for (scanf("%d", &m); m --; ) { 109 scanf("%d %d %d %d", &a, &b, &c, &d); 110 tmp = ask(a, b) * ask(c, d); 111 printf("%d\n", tmp.dis); 112 } 113 return 0; 114 }
相对简单一点了