【hihocoder1167】高等理论计算机科学 (重链剖分 +树状数组)
Descroption
原题链接给你一棵\(~n~\)个点的树和\(~m~\)条链,求两两相交的链有多少对,两条链相交当且仅当有至少一个公共点。\(~1 \leq n, m \leq 10 ^ 5~\).
Solution
一个很直观的想法是把每一条链路径上的权值\(+1\),然后计算每一条链内多出来的权值为多少,显然这样是错的,因为两条链的交集可能不止有一个点,那么可以把每一条链路径上的点权\(+1\),边权\(-1\),再算多出来多少就好了。然而我不会这个啊。
考虑一个性质:
两条链相交当且仅当一条链的\(LCA\)在另一条链上
至于怎么证明,可以画图推推反例发现找不到,为了方便,设两条链为\(C1,C2\),若\(~LCA_{C1}~\)不在\(~C2~\)内,可以有两种情况:\(~①~\)\(C1~\)和\(C2~\)没有交集.\(~②~\)\(LCA_{C2}~\)在\(~C1~\)上. 这基于树上结点的父亲的唯一性。于是就把每一条链的\(~LCA~\)的权值\(+1\),最后统计每一条链内权值和就好了,注意减去重复的情况。
Code
#include <bits/stdc++.h>
#define For(i, j, k) for (int i = j; i <= k; ++i)
#define Travel(i, u) for (int i = beg[u], v = to[i]; i; v = to[i = nex[i]])
using namespace std;
inline int read() {
int x = 0, p = 1; char c = getchar();
for (; !isdigit(c); c = getchar()) if(c == '-') p = -1;
for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
return x * p;
}
inline void File() {
freopen("P1167.in", "r", stdin);
freopen("P1167.out", "w", stdout);
}
typedef long long ll;
const int N = 1e5 + 10;
int e = 1, beg[N], nex[N << 1], to[N << 1], u, v, tmp[N];
int dep[N], dfn[N], top[N], fa[N], son[N], n, m, siz[N];
struct BIT {
int c[N];
inline void update(int x, int v) { for (; x <= n; x += x & -x) c[x] += v; }
inline int query(int x) { int res = 0; for (; x; x -= x & -x) res += c[x]; return res; }
inline int query(int l, int r) { return query(r) - query(l - 1);}
} T;
inline void add(int x, int y) {
to[++ e] = y, nex[e] = beg[x], beg[x] = e;
to[++ e] = x, nex[e] = beg[y], beg[y] = e;
}
inline void dfs1(int u, int f = 0) {
dep[u] = dep[fa[u] = f] + 1, siz[u] = 1;
Travel(i, u) if (v ^ f) {
dfs1(v, u), siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
int clk = 0;
inline void dfs2(int u) {
dfn[u] = ++ clk, top[u] = son[fa[u]] == u ? top[fa[u]] : u;
if (son[u]) dfs2(son[u]);
Travel(i, u) if (v ^ fa[u] && v ^ son[u]) dfs2(v);
}
inline int lca(int x, int y, int ty) {
int res = 0;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
res += T.query(dfn[top[x]], dfn[x]), x = fa[top[x]];
}
if (dep[x] < dep[y]) swap(x, y);
return ty ? res + T.query(dfn[y], dfn[x]) : y;
}
struct Chain { int x, y, lca; } P[N];
int main() {
File();
n = read(), m = read();
For(i, 2, n) u = read(), v = read(), add(u, v);
dfs1(1), dfs2(1);
For(i, 1, m) {
P[i].x = read(), P[i].y = read();
++ tmp[P[i].lca = lca(P[i].x, P[i].y, 0)];
T.update(dfn[P[i].lca], 1);
}
ll ans = 0;
For(i, 1, m) ans += lca(P[i].x, P[i].y, 1) - 1;
For(i, 1, n) ans -= 1ll * tmp[i] * (tmp[i] - 1) >> 1ll;
cout << ans << endl;
return 0;
}