hihoCoder #1954 : 压缩树(虚树)

题意

有一棵 \(n\) 个节点且以 \(1\) 为根的树,把它复制成 \(m\) 个版本,有 \(q\) 次操作,每次对 \([l, r]\) 这些版本的 \(v\) 节点到根的路径收缩起来。

收缩 \(v\) 也就是把 \(v\) 到根路径上(除了根)所有点的父亲都变成根。

最后查询每个版本的每个点的 \(dep\) 之和。

数据范围

\(n, m, q \le 2 \times 10^5\)

题解

操作顺序是无所谓的,我们假设操作了点集 \(S\) ,那么最后被缩上去的点其实就是 \(\{S, root\}\) 构成虚树经过的节点。

每个点的深度其实它原来的深度减去它到根(除了根与根的儿子)被缩的点的个数。

考虑祖先对它的贡献是比较麻烦的,我们不妨考虑它对于祖先的贡献,其实就是每个深度 \(\ge 2\) 的节点的子树 \(size\) 之和。

那么我们把操作离线,只需要动态维护虚树经过所有点的权值和。

这其实是一个经典的动态虚树的问题,按照 \(dfs\) 序,用 std :: set 维护当前的点集,假设插入点为 \(k\) 找到它的前驱 \(l\) 与后继 \(r\) ,令 \(\mathrm{LCA}(l, k), \mathrm{LCA}(r, k)\) 深度较大点为 \(f\) ,那么这次新产生的路径是 \((k, f)\) 的路径(注意 \(f\) 原来就是存在于虚树中的,需要去掉),删除是类似的。

注意可能一个点被缩多次,我们需要利用 std :: multiset ,然后每次插入删除的时候查找是否还存在即可。

复杂度是 \(\mathcal O((n + q) \log n + m)\) 的。

代码

#include <bits/stdc++.h>

#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl

using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }

inline int read() {
	int x(0), sgn(1); char ch(getchar());
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
	for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
	return x * sgn;
}

void File() {
#ifdef zjp_shadow
	freopen ("1954.in", "r", stdin);
	freopen ("1954.out", "w", stdout);
#endif
}

const int N = 2e5 + 1e3;

vector<int> G[N];

ll ans = 0, sum[N];

int n, m, q, dep[N], anc[N][20], Log2[N], sz[N], dfn[N];

void Dfs_Init(int u, int fa = 0) {
	static int clk = 0; dfn[u] = ++ clk;
	dep[u] = dep[anc[u][0] = fa] + 1;
	ans += dep[u]; sz[u] = 1;
	for (int v : G[u]) if (v != fa) Dfs_Init(v, u), sz[u] += sz[v];
}

void Get_Sum(int u, int fa = 0) {
	sum[u] = sum[fa] + (dep[u] > 1) * sz[u];
	for (int v : G[u]) if (v != fa) Get_Sum(v, u);
}

struct Cmp {
	inline bool operator () (const int &lhs, const int &rhs) { return dfn[lhs] < dfn[rhs]; }
};

vector<int> add[N], del[N]; multiset<int, Cmp> S;

inline int Lca(int x, int y) {
	if (dep[x] < dep[y]) swap(x, y);
	int gap = dep[x] - dep[y];
	For (i, 0, Log2[gap]) 
		if (gap >> i & 1) x = anc[x][i];
	if (x == y) return x;
	Fordown (i, Log2[dep[x]], 0)
		if (anc[x][i] != anc[y][i]) x = anc[x][i], y = anc[y][i];
	return anc[x][0];
}

int Find(int x) {
	PII res; auto it = S.upper_bound(x);
	if (it != S.end()) {
		int tmp = Lca(*it, x); chkmax(res, {dep[tmp], tmp});
	}
	if (it != S.begin()) {
		int tmp = Lca(*prev(it), x); chkmax(res, {dep[tmp], tmp});
	}
	return res.second ? res.second : x;
}

int main () {

	File();

	n = read(); m = read(); q = read();
	For (i, 1, n - 1) {
		int u = read(), v = read();
		G[u].push_back(v); G[v].push_back(u);
	}

	while (q --) {
		int l = read(), r = read(), v = read();
		add[l].push_back(v); del[r + 1].push_back(v);
	}

	dep[0] = -1; Dfs_Init(1); Get_Sum(1);

	For (i, 2, n) Log2[i] = Log2[i >> 1] + 1;
	For (j, 1, Log2[n]) For (i, 1, n)
		anc[i][j] = anc[anc[i][j - 1]][j - 1];

	S.insert(1);
	For (i, 1, m) {
		for (int x : add[i]) {
			if (S.find(x) == S.end())
				ans -= sum[x] - sum[Find(x)];
			S.insert(x);
		}
		for (int x : del[i]) {
			S.erase(S.find(x));
			if (S.find(x) == S.end())
				ans += sum[x] - sum[Find(x)];
		}
		printf ("%lld%c", ans, i == iend ? '\n' : ' ');
	}

	return 0;

}
posted @ 2019-04-17 21:12  zjp_shadow  阅读(427)  评论(0编辑  收藏  举报