bzoj5518 & loj3046 「ZJOI2019」语言 线段树合并+树链的并
题目传送门
题解
首先问题就是问有多少条路径是给定的几条路径中的一条的一个子段。
先考虑链的做法。
枚举右端点 \(i\),那么求出 \(j\) 表示经过 \(i\) 的路径,左端点最小是 \(j\),那么右端点 \(i\) 的贡献就是 \(i-j+1\)。
至于求出 \(j\) 可以用直接线性地从右向左扫一遍,在右端点处枚举路径就可以了。
那么问题回到树上。
我们考虑也枚举最终的路径的一个端点。
那么,这个端点的贡献,应该就是经过这个端点的路径的并的长度。所以如果把这个端点看做根的话,那么贡献就是经过这个端点的树链的并。
根据之前做过的 bzoj3991 [SDOI2015] 寻宝游戏 的经验,树链的并的长度的二倍等于按照 dfs 序排序以后,相邻的两个点的距离的和,加上第一个点到最后一个点的距离。
那么,我们只需要能够很快地求出经过一个点 \(x\) 的路径的端点的集合,就可以通过数据结构维护出 \(x\) 的贡献了。
如何计算经过 \(x\) 的路径的端点的集合呢?
很简单,可以使用树上差分,对于路径 \(x \longleftrightarrow lca \longleftrightarrow y\),在 \(x\) 的集合中放上 \(x, y\) 两个点,在 \(y\) 的集合中放上 \(x, y\) 两个点,最后在 \(fa[lca]\) 中删去 \(x, y\)。然后使用线段树合并可以把集合递交给父节点。
感受:ZJOI 竟然有签到题。
如果使用 RMQ 求解 LCA,那么时间复杂度 \(O(n\log n)\)。
#include<bits/stdc++.h>
#define fec(i, x, y) (int i = head[x], y = g[i].to; i; i = g[i].ne, y = g[i].to)
#define dbg(...) fprintf(stderr, __VA_ARGS__)
#define File(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
#define fi first
#define se second
#define pb push_back
template<typename A, typename B> inline char smax(A &a, const B &b) {return a < b ? a = b , 1 : 0;}
template<typename A, typename B> inline char smin(A &a, const B &b) {return b < a ? a = b , 1 : 0;}
typedef long long ll; typedef unsigned long long ull; typedef std::pair<int, int> pii;
template<typename I> inline void read(I &x) {
int f = 0, c;
while (!isdigit(c = getchar())) c == '-' ? f = 1 : 0;
x = c & 15;
while (isdigit(c = getchar())) x = (x << 1) + (x << 3) + (c & 15);
f ? x = -x : 0;
}
const int N = 1e5 + 7;
const int LOG = 18;
int n, m, dfc, dfc2, nod;
ll ans;
int f[N], dfn[N], pre[N], seq[N << 1], dfn2[N], lc[N << 1][LOG], dep[N];
int rt[N];
struct Edge { int to, ne; } g[N << 1]; int head[N], tot;
inline void addedge(int x, int y) { g[++tot].to = y, g[tot].ne = head[x], head[x] = tot; }
inline void adde(int x, int y) { addedge(x, y), addedge(y, x); }
inline void dfs1(int x, int fa = 0) {
f[x] = fa, dfn[x] = ++dfc, dfn2[x] = ++dfc2, pre[dfc] = seq[dfc2] = x, dep[x] = dep[fa] + 1;
for fec(i, x, y) if (y != fa) dfs1(y, x), seq[++dfc2] = x;
}
inline void rmq_init() {
for (int i = 1; i <= dfc2; ++i) lc[i][0] = seq[i];
for (int j = 1; (1 << j) <= dfc2; ++j)
for (int i = 1; i + (1 << j) - 1 <= dfc2; ++i) {
int a = lc[i][j - 1], b = lc[i + (1 << (j - 1))][j - 1];
lc[i][j] = dep[a] < dep[b] ? a : b;
}
}
inline int qmin(int l, int r) {
int k = std::__lg(r - l + 1), a = lc[l][k], b = lc[r - (1 << k) + 1][k];
return dep[a] < dep[b] ? a : b;
}
inline int lca(int x, int y) { return dfn2[x] < dfn2[y] ? qmin(dfn2[x], dfn2[y]) : qmin(dfn2[y], dfn2[x]); }
inline int dist(int x, int y) { return dep[x] + dep[y] - (dep[lca(x, y)] << 1); }
struct Node { int lc, rc, val, s, ls, rs; } t[N * 120];
inline void pushup(int o) {
if (t[t[o].lc].ls) t[o].ls = t[t[o].lc].ls; else t[o].ls = t[t[o].rc].ls;
if (t[t[o].rc].rs) t[o].rs = t[t[o].rc].rs; else t[o].rs = t[t[o].lc].rs;
t[o].val = t[t[o].lc].val + t[t[o].rc].val;
if (t[t[o].lc].rs && t[t[o].rc].ls) t[o].val += dist(t[t[o].lc].rs, t[t[o].rc].ls);
t[o].s = t[t[o].lc].s + t[t[o].rc].s;
// dbg("o = %d, t[o].lc = %d, t[o].rc = %d, t[o].ls = %d, t[o].rs = %d, t[o].val = %d, t[o].s = %d\n", o, t[o].lc, t[o].rc, t[o].ls, t[o].rs, t[o].val, t[o].s);
assert((!!t[o].ls) == (!!t[o].rs));
if (t[o].ls) assert(!((t[o].val + dist(t[o].ls, t[o].rs)) & 1));
// assert((!!t[o].s) == (!!t[o].ls));
}
inline void ins(int &o, int L, int R, int x, int k) {
if (!o) o = ++nod;
t[o].s += k;
if (L == R) return (void)(t[o].ls = t[o].rs = t[o].s ? pre[L] : 0);
int M = (L + R) >> 1;
if (x <= M) ins(t[o].lc, L, M, x, k);
else ins(t[o].rc, M + 1, R, x, k);
pushup(o);
}
inline int merge(int o, int p) {
if (!o || !p) return o ^ p;
t[o].lc = merge(t[o].lc, t[p].lc);
t[o].rc = merge(t[o].rc, t[p].rc);
if (t[o].lc || t[o].rc) pushup(o);
else t[o].s = t[o].s + t[p].s, t[o].ls = t[o].rs = t[o].s ? t[o].ls | t[p].ls : 0;
return o;
}
inline void debug(int o, int L, int R) {
// dbg("o = %d, L = %d, R = %d, t[o].lc = %d, t[o].rc = %d, t[o].ls = %d, t[o].rs = %d, t[o].val = %d, t[o].s = %d\n", o, L, R, t[o].lc, t[o].rc, t[o].ls, t[o].rs, t[o].val, t[o].s);
assert(t[o].s >= 0);
assert((!!t[o].s) == !!(t[o].ls));
if (L == R) return;
int M = (L + R) >> 1;
debug(t[o].lc, L, M);
debug(t[o].rc, M + 1, R);
}
inline void dfs2(int x, int fa = 0) {
for fec(i, x, y) if (y != fa) dfs2(y, x), rt[x] = merge(rt[x], rt[y]);
ans += (t[rt[x]].val + dist(t[rt[x]].ls, t[rt[x]].rs)) / 2;
// dbg("****** x = %d, ls = %d, rs = %d, dif = %d, %d, %d\n", x, t[rt[x]].ls, t[rt[x]].rs, (t[rt[x]].val + dist(t[rt[x]].ls, t[rt[x]].rs)) / 2, t[rt[x]].val, dist(t[rt[x]].ls, t[rt[x]].rs));
// debug(rt[x], 1, n);
assert(!((t[rt[x]].val + dist(t[rt[x]].ls, t[rt[x]].rs)) & 1));
}
inline void work() {
dfs2(1);
printf("%lld\n", ans / 2);
}
inline void init() {
read(n), read(m);
int x, y;
for (int i = 1; i < n; ++i) read(x), read(y), adde(x, y);
dfs1(1), rmq_init();
// for (int i = 1; i <= n; ++i) dbg("i = %d, dfn[i] = %d, dfn2[i] = %d\n", i, dfn[i], dfn2[i]);
for (int i = 1; i <= m; ++i) {
int x, y, p;
read(x), read(y);
p = lca(x, y);
// dbg("x = %d, y = %d, p = %d\n", x, y, p);
ins(rt[x], 1, n, dfn[y], 1), ins(rt[x], 1, n, dfn[x], 1);
ins(rt[y], 1, n, dfn[x], 1), ins(rt[y], 1, n, dfn[y], 1);
if (f[p]) ins(rt[f[p]], 1, n, dfn[x], -2), ins(rt[f[p]], 1, n, dfn[y], -2);
}
// dbg("****************** %d\n", lc[1][1]);
}
int main() {
#ifdef hzhkk
freopen("hkk.in", "r", stdin);
#endif
init();
work();
fclose(stdin), fclose(stdout);
return 0;
}