DFS序+线段树 hihoCoder 1381 Little Y's Tree(树的连通块的直径和)
时间限制:24000ms
单点时限:4000ms
内存限制:512MB
描述
小Y有一棵n个节点的树,每条边都有正的边权。
小J有q个询问,每次小J会删掉这个树中的k条边,这棵树被分成k+1个连通块。小J想知道每个连通块中最远点对距离的和。
这里的询问是互相独立的,即每次都是在小Y的原树上进行操作。
输入
第一行一个整数n,接下来n-1行每行三个整数u,v,w,其中第i行表示第i条边边权为wi,连接了ui,vi两点。
接下来一行一个整数q,表示有q组询问。
对于每组询问,第一行一个正整数k,接下来一行k个不同的1到n-1之间的整数,表示删除的边的编号。
1<=n,q,Σk<=105, 1<=w<=109
输出
共q行,每行一个整数表示询问的答案。
题解:
首先考虑给出两个点集,如何求这两个点集合并之后的直径,方法是把两个点集的直径分别求出来,然后对于这4个点,求出两两之间距离的最大值。
于是可以按dfs序建立线段树,然后求出每个区间的直径。
而对于一个询问,删掉k条边,每棵子树都对应的dfs序中的若干区间,而且区间总个数不会超过2k,对于每个区间可以在线段树中查询。
时间复杂度O(nlog^2n)。
代码:
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 1e5 + 5; const int D = 20; struct Edge { int u, v, w; }; int L[N], R[N], p[N], rt[N][D], dep[N]; ll d[N]; int dfs_clock; vector<Edge> edges; vector<int> id[N]; int n, m; void init_edge() { edges.clear(); for (int i=1; i<=n; ++i) id[i].clear(); m = 0; } void add_edge(int u, int v, int w) { edges.push_back((Edge){u, v, w}); m = edges.size(); id[u].push_back(m-1); } void DFS(int u, int fa) { L[u] = ++dfs_clock; p[dfs_clock] = u; dep[u] = dep[fa] + 1; rt[u][0] = fa; for (int i: id[u]) { Edge &e = edges[i]; if (e.v == fa) continue; d[e.v] = d[u] + e.w; DFS(e.v, u); } R[u] = dfs_clock; } void init_LCA() { for (int j=1; j<D; ++j) { for (int i=1; i<=n; ++i) { rt[i][j] = rt[rt[i][j-1]][j-1]; } } } int LCA(int u, int v) { if (dep[u] < dep[v]) swap(u, v); for (int i=0; i<D; ++i) { if ((dep[u]-dep[v]) >> i & 1) u = rt[u][i]; } if (u == v) return u; for (int i=D-1; i>=0; --i) { if (rt[u][i] != rt[v][i]) { u = rt[u][i]; v = rt[v][i]; } } return rt[u][0]; } ll dis(int u, int v) { return d[u] + d[v] - 2 * d[LCA(u, v)]; } struct Node { ll d; int a, b; Node(ll d=0, int a=0, int b=0) : d(d), a(a), b(b) {} bool operator < (const Node &rhs) const { return d < rhs.d; } }; Node nd[N<<2]; Node better(Node x, Node y) { if (x.d == -1) return y; if (y.d == -1) return x; Node z1 = Node(dis(x.a, y.a), x.a, y.a); Node z2 = Node(dis(x.a, y.b), x.a, y.b); Node z3 = Node(dis(x.b, y.a), x.b, y.a); Node z4 = Node(dis(x.b, y.b), x.b, y.b); return max({x, y, z1, z2, z3, z4}); } #define lch o << 1 #define rch o << 1 | 1 void build(int o, int l, int r) { if (l == r) { nd[o] = Node(0, p[l], p[l]); return ; } int mid = l + r >> 1; build(lch, l, mid); build(rch, mid+1, r); nd[o] = better(nd[lch], nd[rch]); } Node query(int o, int l, int r, int ql, int qr) { if (ql <= l && r <= qr) { return nd[o]; } int mid = l + r >> 1; Node ret = Node(-1, 0, 0); if (ql <= mid) ret = better(ret, query(lch, l, mid, ql, qr)); if (qr > mid) ret = better(ret, query(rch, mid+1, r, ql, qr)); return ret; } bool cmp(int a, int b) { return L[a] < L[b]; } int x[N], s[N]; vector<int> edge[N]; ll ans; void prepare() { dfs_clock = 0; dep[0] = 0; d[1] = 0; DFS(1, 0); init_LCA(); build(1, 1, n); } void DFS(int u) { Node res = Node(-1, x[u], x[u]); int ql = L[x[u]], qr = R[x[u]]; for (int v: edge[u]) { DFS(v); res = better(res, query(1, 1, n, ql, L[x[v]]-1)); ql = R[x[v]] + 1; } res = better(res, query(1, 1, n, ql, qr)); ans += res.d; } int main() { scanf("%d", &n); init_edge(); for (int i=1; i<n; ++i) { int u, v, w; scanf("%d%d%d", &u, &v, &w); add_edge(u, v, w); add_edge(v, u, w); } prepare(); int q, k; scanf("%d", &q); while (q--) { scanf("%d", &k); int idx; for (int i=1; i<=k; ++i) { scanf("%d", &idx); idx--; Edge &e = edges[idx*2]; if (dep[e.u] > dep[e.v]) x[i] = e.u; else x[i] = e.v; } sort(x+1, x+1+k, cmp); x[0] = 1; int nn = 1; s[nn] = 0; for (int i=0; i<=k; ++i) edge[i].clear(); //s[nn]:0~k, x[s[n]]:1 or x[1~k] for (int i=1; i<=k; ++i) { while (!(L[x[s[nn]]] <= L[x[i]] && R[x[i]] <= R[x[s[nn]]])) nn--; edge[s[nn]].push_back(i); s[++nn] = i; } ans = 0; DFS(0); printf("%lld\n", ans); } return 0; }
编译人生,运行世界!