CodeForces 2033G Sakurako and Chefir 题解
题意:给定一棵树,多次询问在以节点 \(v\) 的 \(k\) 级祖先为根的子树内,从点 \(v\) 出发的最长简单路径。
显然可以将所求路径分为终点在 \(v\) 的子树内和终点在 \(v\) 的子树外,前者容易处理。
考虑其他节点贡献的路径,容易想到枚举 \(v\) 每个合法祖先 \(u\) 作为 LCA,查询 \(v\) 所在的子树与 \(u\) 的其他子树之间的贡献。贡献可以使用主席树处理。具体而言,由树上路径计算方式 \(d = dis_u + dis_v - 2 \cdot dis_{lca}\),设 \(rt_u\) 表示记录节点 \(u\) 的子树内所有点的深度的主席树,查询时减去 \(v\) 所在子树的所有点的深度,并查询最大深度。
于是得到一个 \(O(qn \log n)\) 的做法,考虑优化。
发现每次查询时跳了很多祖先,而对于一个确定的祖先与其确定的子树,没有必要查询多次。使用树上启发式合并:预处理重儿子的答案作为该节点的权值,树剖查询链最大值,跳轻链时再使用前述方法。优化至 \(O(q \log^2 n)\),可以通过。
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
const int inf = 1e9;
int lg[200005];
int n, q;
vector<int> vec[200005];
int rt[200005];
struct node {
int ls, rs, val;
} d[16000005];
int cnt;
static inline int clone(int p) {
d[++cnt] = d[p];
return cnt;
}
static inline int insert(int x, int s, int t, int c, int p) {
p = clone(p);
if (s == t) {
d[p].val += c;
return p;
}
int mid = (s + t) >> 1;
if (x <= mid)
d[p].ls = insert(x, s, mid, c, d[p].ls);
else
d[p].rs = insert(x, mid + 1, t, c, d[p].rs);
d[p].val = d[d[p].ls].val + d[d[p].rs].val;
return p;
}
static inline int merge(int s, int t, int u, int v) {
if (!u || !v)
return clone(u | v);
int p = ++cnt;
if (s == t) {
d[p].val = d[u].val + d[v].val;
return p;
}
int mid = (s + t) >> 1;
d[p].ls = merge(s, mid, d[u].ls, d[v].ls);
d[p].rs = merge(mid + 1, t, d[u].rs, d[v].rs);
d[p].val = d[d[p].ls].val + d[d[p].rs].val;
return p;
}
static inline int query(int s, int t, int u, int v) {
if (s == t)
return d[v].val - d[u].val > 0 ? s : -inf;
int mid = (s + t) >> 1;
if (d[d[v].rs].val - d[d[u].rs].val > 0)
return query(mid + 1, t, d[u].rs, d[v].rs);
return query(s, mid, d[u].ls, d[v].ls);
}
int dep[200005];
int st[200005][20];
int siz[200005];
int son[200005];
int top[200005];
int dfn[200005], dfn_clock;
int nfd[200005];
int pre[200005][20];
static inline void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
st[u][0] = fa;
for (int i = 1; i <= 18; ++i)
st[u][i] = st[st[u][i - 1]][i - 1];
siz[u] = 1;
rt[u] = insert(dep[u], 1, n, 1, rt[u]);
for (auto v : vec[u]) {
if (v == fa)
continue;
dfs(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
rt[u] = merge(1, n, rt[u], rt[v]);
}
}
static inline void dfs2(int u) {
nfd[dfn[u] = ++dfn_clock] = u;
if (!son[u]) {
pre[dfn[u]][0] = -inf;
return;
}
top[son[u]] = top[u];
dfs2(son[u]);
pre[dfn[u]][0] = query(1, n, rt[son[u]], rt[u]) - 2 * dep[u];
for (auto v : vec[u]) {
if (v == st[u][0] || v == son[u])
continue;
top[v] = v;
dfs2(v);
}
}
static inline int jump(int x, int k) {
for (int i = 0; i <= 18; ++i)
if ((k >> i) & 1)
x = st[x][i];
return x;
}
static inline int LCA(int u, int v) {
if (dep[u] < dep[v])
swap(u, v);
u = jump(u, dep[u] - dep[v]);
if (u == v)
return u;
for (int i = 18; i >= 0; --i)
if (st[u][i] != st[v][i]) {
u = st[u][i];
v = st[v][i];
}
return st[u][0];
}
static inline void build() {
for (int i = 2; i <= n; ++i)
lg[i] = lg[i >> 1] + 1;
for (int j = 1; j <= 18; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
pre[i][j] = max(pre[i][j - 1], pre[i + (1 << (j - 1))][j - 1]);
}
static inline int query(int l, int r) {
int len = lg[r - l + 1];
return max(pre[l][len], pre[r - (1 << len) + 1][len]);
}
static inline void solve() {
cin >> n;
for (int i = 1; i <= cnt; ++i)
d[i] = {0, 0, 0};
cnt = dfn_clock = 0;
for (int i = 1; i <= n; ++i) {
vec[i].clear();
rt[i] = son[i] = top[i] = 0;
for (int j = 0; j <= 18; ++j)
st[i][j] = pre[i][j] = 0;
}
for (int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
vec[u].push_back(v);
vec[v].push_back(u);
}
dfs(1, 0);
top[1] = 1;
dfs2(1);
build();
cin >> q;
while (q--) {
int x, k;
cin >> x >> k;
int u = st[x][0];
int v = max(jump(x, k), 1);
int ans = min(k, dep[u]);
while (u && dep[u] > dep[v] && top[u] != top[v]) { // DSU on tree
int w = query(1, n, rt[jump(x, dep[x] - dep[u] - 1)], rt[u]);
ans = max(ans, dep[x] + w - 2 * dep[u]);
if (dfn[top[u]] < dep[u])
ans = max(ans, dep[x] + query(dfn[top[u]], dfn[u] - 1));
u = st[top[u]][0];
}
if (dep[u] >= dep[v]) {
int w = query(1, n, rt[jump(x, dep[x] - dep[u] - 1)], rt[u]);
ans = max(ans, dep[x] + w - 2 * dep[u]);
if (dfn[v] < dfn[u])
ans = max(ans, dep[x] + query(dfn[v], dfn[u] - 1));
}
// 这是暴力的实现方式
// int ans = 0;
// int u = st[x][0];
// --k;
// while (u && k >= 0) {
// int w = query(1, n, rt[jump(x, dep[x] - dep[u] - 1)], rt[u]);
// ans = max(ans, dep[x] + w - 2 * dep[u]);
// u = st[u][0];
// --k;
// }
ans = max(ans, query(1, n, 0, rt[x]) - dep[x]);
cout << ans << ' ';
}
cout << endl;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
solve();
cout.flush();
return 0;
}