WC2024 线段树

洛谷传送门

若一个结点 \([l_i, r_i)\) 已知就连边 \((l_i, r_i)\),那么子集满足条件当且仅当每对 \((L_i, R_i)\) 都连通。

考虑在树形结构上 dp。发现若 \(l, r\) 不连通,设 \(l\) 所在连通块点编号最大值为 \(i\),那么 \(r\) 所在连通块点编号最小值 \(> i\)

于是设 \(f_{u, i}\)\(u\) 结点 \(l, r\) 不连通,\(l\) 所在连通块点编号最大值为 \(i\) 且满足子树中所有条件的方案数(在子树中定义为 \(L_i, R_i\) 至少有一个在子树中),\(g_u\)\(u\) 结点 \(l, r\) 连通且满足子树中所有条件的方案数。考虑 \(l, r\) 之间是否连边,就有转移(设 \(v\)\(u\) 的左儿子,\(w\)\(u\) 的右儿子):

\[2 g_v g_w \to g_u \]

\[g_v f_{w, i} \to g_u, g_v f_{w, i} \to f_{u, i} \]

\[g_w f_{v, i} \to g_u, g_w f_{v, i} \to f_{u, i} \]

\[f_{v, i} f_{w, j} \to g_u, f_{v, i} f_{w, j} \to f_{u, i/j} \]

第四种转移要注意,\([i + 1, j]\) 的部分会和外界不连通,所以不能存在 \((L_i, R_i)\) 使得一个在 \([i + 1, j]\) 里面一个不在。然后转移到 \(f_{u, i}\)\(f_{u, j}\) 都是满足状态定义的,可以任意转移到一个。

考虑异或哈希,对每对 \((L_i, R_i)\),让 \(a_{L_i}, a_{R_i}\) 都异或上一个随机数,然后对 \(a\) 做一遍前缀异或,那么第四种转移的充要条件为 \(a_i = a_j\)

此时我们发现 \(a_i\) 相等的状态可以放在一起考虑。若把 \(i\) 改写成 \(a_i\),那么:

\[2 g_v g_w \to g_u \]

\[g_v f_{w, i} \to g_u, g_v f_{w, i} \to f_{u, i} \]

\[g_w f_{v, i} \to g_u, g_w f_{v, i} \to f_{u, i} \]

\[f_{v, i} f_{w, i} \to g_u, f_{v, i} f_{w, i} \to f_{u, i} \]

线段树合并维护即可。在叶子处合并 dp 值,若走到了非公共结点就打个乘法 tag。

时间复杂度 \(O(n \log n)\),默认 \(n, m\) 同阶。

code
// Problem: P10145 [WC/CTS2024] 线段树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P10145
// Memory Limit: 512 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 400100;
const ll mod = 998244353;

ll n, m, p[maxn], nt, ls[maxn], rs[maxn], tot, g[maxn];
ull a[maxn], lsh[maxn];
pii b[maxn];
mt19937_64 rnd(chrono::steady_clock::now().time_since_epoch().count());
int rt[maxn];

int build(int l, int r, int &k) {
	int u = ++nt;
	b[u] = mkp(l, r);
	if (l + 1 == r) {
		return u;
	}
	int mid = p[++k];
	ls[u] = build(l, mid, k);
	rs[u] = build(mid, r, k);
	return u;
}

namespace SGT {
	int ls[maxn * 20], rs[maxn * 20];
	ll tag[maxn * 20], sum[maxn * 20];
	
	inline void pushup(int x) {
		sum[x] = (sum[ls[x]] + sum[rs[x]]) % mod;
	}
	
	inline void pushtag(int x, ll y) {
		if (!x) {
			return;
		}
		sum[x] = sum[x] * y % mod;
		tag[x] = tag[x] * y % mod;
	}
	
	inline void pushdown(int x) {
		if (tag[x] == 1) {
			return;
		}
		pushtag(ls[x], tag[x]);
		pushtag(rs[x], tag[x]);
		tag[x] = 1;
	}
	
	void update(int &rt, int l, int r, int x) {
		if (!rt) {
			rt = ++nt;
			tag[rt] = 1;
		}
		if (l == r) {
			sum[rt] = 1;
			return;
		}
		pushdown(rt);
		int mid = (l + r) >> 1;
		(x <= mid) ? update(ls[rt], l, mid, x) : update(rs[rt], mid + 1, r, x);
		pushup(rt);
	}
	
	ll query(int rt, int l, int r, int x) {
		if (!rt) {
			return 0;
		}
		if (l == r) {
			return sum[rt];
		}
		pushdown(rt);
		int mid = (l + r) >> 1;
		return x <= mid ? query(ls[rt], l, mid, x) : query(rs[rt], mid + 1, r, x);
	}
	
	int merge(int u, int v, ll x, ll y, int l, int r) {
		if (!u) {
			pushtag(v, y);
			return v;
		}
		if (!v) {
			pushtag(u, x);
			return u;
		}
		if (l == r) {
			sum[u] = (sum[u] * x + sum[v] * y + sum[u] * sum[v]) % mod;
			return u;
		}
		pushdown(u);
		pushdown(v);
		int mid = (l + r) >> 1;
		ls[u] = merge(ls[u], ls[v], x, y, l, mid);
		rs[u] = merge(rs[u], rs[v], x, y, mid + 1, r);
		pushup(u);
		return u;
	}
}

void dfs(int u) {
	if (!ls[u]) {
		SGT::update(rt[u], 1, tot, a[b[u].fst]);
		g[u] = 1;
		return;
	}
	int v = ls[u], w = rs[u];
	dfs(v);
	dfs(w);
	rt[u] = SGT::merge(rt[v], rt[w], g[w], g[v], 1, tot);
	g[u] = (g[v] * g[w] * 2 + SGT::sum[rt[u]]) % mod;
}

void solve() {
	scanf("%lld%lld", &n, &m);
	for (int i = 1; i < n; ++i) {
		scanf("%lld", &p[i]);
	}
	while (m--) {
		int l, r;
		scanf("%d%d", &l, &r);
		ull x = rnd();
		a[l] ^= x;
		a[r] ^= x;
	}
	for (int i = 1; i < n; ++i) {
		a[i] ^= a[i - 1];
	}
	int k = 0;
	build(0, n, k);
	bool fl = 0;
	for (int i = 0; i < n; ++i) {
		lsh[++tot] = a[i];
		fl |= (a[i] == 0);
	}
	sort(lsh + 1, lsh + tot + 1);
	tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
	for (int i = 0; i < n; ++i) {
		a[i] = lower_bound(lsh + 1, lsh + tot + 1, a[i]) - lsh;
	}
	dfs(1);
	printf("%lld\n", (g[1] + (fl ? SGT::query(rt[1], 1, tot, 1) : 0)) % mod);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

posted @ 2024-02-13 18:48  zltzlt  阅读(48)  评论(0编辑  收藏  举报