【hihocoder1167】高等理论计算机科学 (重链剖分 +树状数组)

Descroption

原题链接给你一棵\(~n~\)个点的树和\(~m~\)条链,求两两相交的链有多少对,两条链相交当且仅当有至少一个公共点。\(~1 \leq n, m \leq 10 ^ 5~\).

Solution

一个很直观的想法是把每一条链路径上的权值\(+1\),然后计算每一条链内多出来的权值为多少,显然这样是错的,因为两条链的交集可能不止有一个点,那么可以把每一条链路径上的点权\(+1\),边权\(-1\),再算多出来多少就好了。然而我不会这个啊。

考虑一个性质:

两条链相交当且仅当一条链的\(LCA\)在另一条链上

至于怎么证明,可以画图推推反例发现找不到,为了方便,设两条链为\(C1,C2\),若\(~LCA_{C1}~\)不在\(~C2~\)内,可以有两种情况:\(~①~\)\(C1~\)\(C2~\)没有交集.\(~②~\)\(LCA_{C2}~\)\(~C1~\)上. 这基于树上结点的父亲的唯一性。于是就把每一条链的\(~LCA~\)的权值\(+1\),最后统计每一条链内权值和就好了,注意减去重复的情况。

Code

#include <bits/stdc++.h>
#define For(i, j, k) for (int i = j; i <= k; ++i)
#define Travel(i, u) for (int i = beg[u], v = to[i]; i; v = to[i = nex[i]])
using namespace std;

inline int read() {
	int x = 0, p = 1; char c = getchar();
	for (; !isdigit(c); c = getchar()) if(c == '-') p = -1;
	for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
	return x * p;
}

inline void File() {
	freopen("P1167.in", "r", stdin);
	freopen("P1167.out", "w", stdout);
}

typedef long long ll;
const int N = 1e5 + 10;
int e = 1, beg[N], nex[N << 1], to[N << 1], u, v, tmp[N];
int dep[N], dfn[N], top[N], fa[N], son[N], n, m, siz[N];

struct BIT {
	int c[N];
	inline void update(int x, int v) { for (; x <= n; x += x & -x) c[x] += v; }
	inline int query(int x) { int res = 0; for (; x; x -= x & -x) res += c[x]; return res; }
	inline int query(int l, int r) { return query(r) - query(l - 1);}
} T;

inline void add(int x, int y) {
	to[++ e] = y, nex[e] = beg[x], beg[x] = e;
	to[++ e] = x, nex[e] = beg[y], beg[y] = e;
}

inline void dfs1(int u, int f = 0) {
	dep[u] = dep[fa[u] = f] + 1, siz[u] = 1;
	Travel(i, u) if (v ^ f) {
		dfs1(v, u), siz[u] += siz[v];	
		if (siz[v] > siz[son[u]]) son[u] = v;
	}
}

int clk = 0;				
inline void dfs2(int u) {
	dfn[u] = ++ clk, top[u] = son[fa[u]] == u ? top[fa[u]] : u;
	if (son[u]) dfs2(son[u]);
	Travel(i, u) if (v ^ fa[u] && v ^ son[u]) dfs2(v);
}

inline int lca(int x, int y, int ty) {
	int res = 0;
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]]) swap(x, y);	
		res += T.query(dfn[top[x]], dfn[x]), x = fa[top[x]];		
	}
	if (dep[x] < dep[y]) swap(x, y);	
	return ty ? res + T.query(dfn[y], dfn[x]) : y;
}

struct Chain { int x, y, lca; } P[N];

int main() {
	File();

	n = read(), m = read();
	For(i, 2, n) u = read(), v = read(), add(u, v); 
	dfs1(1), dfs2(1);

	For(i, 1, m) {
		P[i].x = read(), P[i].y = read();
		++ tmp[P[i].lca = lca(P[i].x, P[i].y, 0)];
		T.update(dfn[P[i].lca], 1);
	}		
	
	ll ans = 0;
	For(i, 1, m) ans += lca(P[i].x, P[i].y, 1) - 1;
	For(i, 1, n) ans -= 1ll * tmp[i] * (tmp[i] - 1) >> 1ll;
	
	cout << ans << endl;	
	return 0;
}

posted @ 2018-09-13 17:26  LSTete  阅读(392)  评论(1编辑  收藏  举报