题解 树套树

题面

二叉查找树(BST)是一种简单的数据结构,本题默认你已经熟悉BST的插入和查询两种操作。

给你一棵树,每个节点有一个BST。有以下两种操作:

  1. \(u,v,k\) :在路径 \((u,v)\) 上每个节点的BST中插入 \(k\)
  2. \(u,k\) :询问节点 \(u\) 的BST中查询 \(k\) 时经过节点的权值和。
    为了避免奇怪的情况发生, 本题每次操作的 \(k\) 是互不相同的。

\(1\le n\le 2\times 10^5\)

题解

想不到吧,40分暴力需要离线扫描线+树剖+treap。

显然不能把这颗 BST 想象成一棵树,需要直接考虑每个节点到根路径上会经过哪些节点。

不难发现,假设当前询问的值为 \(k\),询问时间为 \(t\),则每次跳到 \(k\) 比它小,\(t\) 比它小,\(k\) 最大的点;然后再加上每次跳到 \(k\) 比它大,\(t\) 比它小,\(k\) 最小的点,这些点就是 BST 到根路径上会经过的所有点。

这个东西可以直接线段树维护。由于两者类似,只讲每次跳到 \(k\) 比它大,\(t\) 比它小,\(k\) 最小的点怎么维护。线段树以 \(k\) 为下标,\(t\) 为值,对于每个管辖区间 \([l,r]\) 的节点 \(p\) 记录一个 \(mn,sum\),pushup 时 \(mn\) 直接合并,表示设 \(p\) 的左兄弟最小值为 \(x\),则 \(sum\) 表示 \((l,x)\)\([l,r]\) 中会经过的节点之和(对于左儿子 \(sum\) 没有定义)。需要专门写一个 calc(p,v) 函数表示 \((l,v)\)\([l,r]\) 中会经过的节点之和。这个函数是 \(O(\log n)\) 的,在 pushup 和查询时均会用到。

因此,这个线段树单次操作复杂度时 \(O(n\log^2 n)\) 的。对于树上的情况只需要树上差分+线段树合并即可。

#include <cstdio>
#include <vector>
#include <algorithm>
#define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1 ++)
#define int unsigned

inline int min(const int x, const int y) {return x < y ? x : y;}
const int INF = 1000000000;
char buf[100000], *p1, *p2;
inline int read() {
	char ch;
	int x = 0;
	while ((ch = gc) < 48);
	do x = x * 10 + ch - 48; while ((ch = gc) >= 48);
	return x;
}

struct Edge {int to, nxt;} e[400005];
int head[200005], fa[200005][18], dep[200005], ans[200005], from[200005], to[200005], a[200005], b[200005], tot;
int rt1[200005], rt2[200005], ls[6000005], rs[6000005], sum[6000005], mn[6000005], cnt, n, q;
bool mark[200005], rev;
std::vector<std::pair<int, int> > ins[200005], del[200005], qry[200005];
inline void AddEdge(int u, int v) {
	e[++ tot].to = v, e[tot].nxt = head[u], head[u] = tot;
}
void dfs(int u) {
	dep[u] = dep[fa[u][0]] + 1;
	for (int i = 1; i <= 17; ++ i) fa[u][i] = fa[fa[u][i - 1]][i - 1];
	for (int i = head[u]; i; i = e[i].nxt)
		if (e[i].to != fa[u][0]) fa[e[i].to][0] = u, dfs(e[i].to);
}
int LCA(int u, int v) {
	if (dep[u] < dep[v]) std::swap(u, v);
	int t = dep[u] - dep[v];
	for (int i = 0; i <= 17; ++ i)
		if (t & 1u << i) u = fa[u][i];
	if (u == v) return u;
	for (long i = 17; i >= 0; -- i)
		if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
	return fa[u][0];
}
int calc(int p, int l, int r, int t) {
	if (!p) return 0;
	if (l == r) return mn[p] < t ? (rev ? b[q - l + 1] : b[l]) : 0;
	int mid = 1ull * l + r >> 1;
	if (mn[ls[p]] < t) return calc(ls[p], l, mid, t) + sum[rs[p]];
	else return calc(rs[p], mid + 1, r, t);
}
void pushup(int p, int l, int r) {
	mn[p] = min(mn[ls[p]], mn[rs[p]]), sum[rs[p]] = calc(rs[p], (l + r >> 1) + 1, r, mn[ls[p]]);
}
void insert(int &p, int l, int r, int v, int t) {
	if (!p) mn[p = ++ cnt] = t;
	if (l == r) return;
	int mid = 1ull * l + r >> 1;
	if (v <= mid) insert(ls[p], l, mid, v, t);
	else insert(rs[p], mid + 1, r, v, t);
	pushup(p, l, r);
}
void remove(int& p, int l, int r, int v) {
	if (l == r) {mn[p] = INF, p = 0; return;}
	int mid = l + r >> 1;
	if (v <= mid) remove(ls[p], l, mid, v);
	else remove(rs[p], mid + 1, r, v);
	if (!ls[p] && !rs[p]) p = 0;
	else pushup(p, l, r);
	pushup(p, l, r);
}
std::pair<int, int> query(int p, int l, int r, int v, int t) {
	if (!p) return std::make_pair(0, t);
	if (v <= l) return std::make_pair(calc(p, l, r, t), min(mn[p], t));
	int mid = l + r >> 1;
	if (v > mid) return query(rs[p], mid + 1, r, v, t);
	std::pair<int, int> ansl = query(ls[p], l, mid, v, t), ansr = query(rs[p], mid + 1, r, v, ansl.second);
	return std::make_pair(ansl.first + ansr.first, ansr.second);
}
void merge(int &u, int v, int l, int r) {
	if (!u || !v) return u |= v, void();
	if (l == r) return mn[u]=min(mn[u],mn[v]),void();
	merge(ls[u], ls[v], l, l + r >> 1);
	merge(rs[u], rs[v], (l + r >> 1) + 1, r);
	pushup(u, l, r);
}
void solve(int u) {
	for (int i = head[u]; i; i = e[i].nxt) if (e[i].to != fa[u][0]) {
		solve(e[i].to);
		rev = false, merge(rt1[u], rt1[e[i].to], 1, q);
		rev = true, merge(rt2[u], rt2[e[i].to], 1, q);
	}
	for (auto i : ins[u]) {
		rev = false, insert(rt1[u], 1, q, i.first, i.second);
		rev = true, insert(rt2[u], 1, q, q - i.first + 1, i.second);
	}
	for (auto i : del[u]) {
		rev = false, remove(rt1[u], 1, q, i.first);
		rev = true, remove(rt2[u], 1, q, q - i.first + 1);
	}
	for (auto i : qry[u]) {
		rev = false, ans[i.second] += query(rt1[u], 1, q, i.first, i.second).first;
		rev = true, ans[i.second] += query(rt2[u], 1, q, q - i.first + 1, i.second).first;
	}
}

signed main() {
	mn[0] = INF;
	n = read(), q = read();
	for (int i = 1, u, v; i < n; ++ i)
		u = read(), v = read(), AddEdge(u, v), AddEdge(v, u);
	dfs(1);
	for (int i = 1; i <= q; ++ i)
		if (read() == 1) from[i] = read(), to[i] = read(), a[i] = b[i] = read();
		else from[i] = read(), a[i] = b[i] = read();
	std::sort(b + 1, b + q + 1);
	for (int i = 1; i <= q; ++ i)
		a[i] = std::lower_bound(b + 1, b + q + 1, a[i]) - b;
	for (int i = 1; i <= q; ++ i)
		if (to[i]) {
			ins[from[i]].push_back(std::make_pair(a[i], i));
			ins[to[i]].push_back(std::make_pair(a[i], i));
			del[fa[LCA(from[i], to[i])][0]].push_back(std::make_pair(a[i], i));
		} else qry[from[i]].push_back(std::make_pair(a[i], i)), mark[i] = true;
	solve(1);
	for (int i = 1; i <= q; ++ i) if (mark[i]) printf("%u\n", ans[i]);
	return 0;
}
posted @ 2022-08-28 19:10  zqs2020  阅读(31)  评论(0编辑  收藏  举报