【YBT2022寒假Day3 C】毒瘤染色(LCT)(圆方树)(容斥)

毒瘤染色

题目链接:YBT2022寒假Day3 C

题目大意

要你在线实现一个操作:
一开始有 n 个点,没有边,然后操作会给你一条边。
如果保证加了之后这个图还是沙漠就加上。
然后每次加完边之后问你一开始所有点都是白色做 k 次每次随机选一个点(可能白色)的点把它变成白色,然后问你分别保留黑白点是的连通块个数的期望值。
(有部分分只用求保留白点的)

思路

考虑分别处理加边的操作和询问。

那沙漠一般来讲要么就用圆方树,要么就直接拆链搞,那这个应该是要前面的那个。
你考虑你其实就是要看两个点之间是否连通,这个好搞,和两个点的路径点是否有环,而且会出现新的环。

那我们不如用 LCT 来弄,记录方点圆点个数(方点是一个环的代表点,圆点就是普通点)
然后如果路径中有方点就是有环,就不能加边。(这里没有必要圆方点交错的那种)

然后考虑询问,考虑连通块如何表示:
在是普通的没有环的图上,它可以表示成点数 - 边数。
那现在有环,但是是沙漠,可以表示成点数 - 边数 + 环数。

那由于每次选的概率每个点相同,每个点是黑色或者白色的概率是相同的。
所以边的概率就是它连接的两个点都存在的概率,环存在的概率就是环上所有点都存在的概率。

那我们设 \(w_x\) 为选 \(x\) 个点,都是白点的概率, \(b_x\) 为都是黑点的概率。然后设 \(f_x=w_x+\omega b_x\)
然后答案就是 \(n*f_1-m*f_2+\sum f_{sz_i}\)\(sz_i\) 为每个环的大小)。

然后你在连边的时候更新 \(m,sz_i\) 即可。(一个是加边了就更新,一个是出现了新的环就更新)
然后接着问题就是 \(w_x,b_x\) 怎么求。

\(w_x\) 很好求,就是不被选到:\((\dfrac{n-x}{n})^k\)
那接着就是求 \(b_x\),考虑容斥求:
\(b_x=\sum\limits_{i=0}^x(-1)^i(\dfrac{n-i}{n})^kC_x^i\)

然后就好了。

代码

#include<cstdio>
#include<vector>
#include<algorithm>
#define ll long long
#define mo 998244353

using namespace std;

ll n, q, k, w, m;
ll x, y, lst, tot;
vector <ll> tmp;

ll b1, b2, invn, f2;
ll jc[100001], inv[100001];
ll b[100001], ans;

ll ksm(ll x, ll y) {
	ll re = 1;
	while (y) {
		if (y & 1) re = re * x % mo;
		x = x * x % mo;
		y >>= 1;
	}
	return re;
}

ll C(ll n, ll m) {
	if (n < m || m < 0) return 0;
	return jc[n] * inv[m] % mo * inv[n - m] % mo;
}

struct LCT {
	ll ls[400001], rs[400001], cir[400001];
	ll sz[400001], cirsz[400001], fa[400001];
	bool lzyc[400001];
	
	void Init() {
		for (ll i = 1; i <= n; i++) sz[i] = cir[i] = cirsz[i] = 1;
	}
	
	bool nrt(ll x) {
		return ls[fa[x]] == x || rs[fa[x]] == x;
	}
	
	bool lrs(ll x) {
		return ls[fa[x]] == x;
	}
	
	void up(ll x) {
		sz[x] = sz[ls[x]] + sz[rs[x]] + 1;
		cirsz[x] = cirsz[ls[x]] + cirsz[rs[x]] + cir[x];
	}
	
	void downc(ll now) {
		lzyc[now] ^= 1; swap(ls[now], rs[now]);
	}
	
	void down(ll now) {
		if (lzyc[now]) {
			if (ls[now]) downc(ls[now]);
			if (rs[now]) downc(rs[now]);
			lzyc[now] = 0;
		}
	}
	
	void down_line(ll now) {
		if (nrt(now)) down_line(fa[now]);
		down(now);
	}
	
	void rotate(ll x) {
		ll y = fa[x], z = fa[y];
		ll b = lrs(x) ? rs[x] : ls[x];
		if (z && nrt(y)) (lrs(y) ? ls[z] : rs[z]) = x;
		if (lrs(x)) rs[x] = y, ls[y] = b;
			else ls[x] = y, rs[y] = b;
		fa[x] = z; fa[y] = x;
		if (b) fa[b] = y;
		up(y);
	}
	
	void Splay(ll x) {
		down_line(x);
		while (nrt(x)) {
			if (nrt(fa[x])) {
				if (lrs(x) == lrs(fa[x])) rotate(fa[x]);
					else rotate(x);
			}
			rotate(x);
		}
		up(x);
	}
	
	void access(ll x) {
		ll lst = 0;
		for (; x; x = fa[x]) {
			Splay(x);
			
			rs[x] = lst; up(x);
			lst = x;
		}
	}
	
	void make_root(ll x) {
		access(x);
		Splay(x);
		downc(x);
	}
	
	ll find_root(ll now) {
		access(now);
		Splay(now);
		down(now);
		while (ls[now]) {
			now = ls[now]; down(now);
		}
		return now;
	}
	
	ll select(ll x, ll y) {
		make_root(x);
		access(y);
		Splay(y);
		return y;
	}
	
	bool link(ll x, ll y) {
		if (find_root(x) == find_root(y)) return 0;
		make_root(x);
		fa[x] = y;
		return 1;
	}
	
	void get_all(ll now) {
		tmp.push_back(now);
		if (ls[now]) get_all(ls[now]);
		if (rs[now]) get_all(rs[now]);
	}
}T;

ll get_B(ll x) {
	ll re = 0, di = 1;
	for (ll i = 0; i <= x; i++) {
		(re += di * C(x, i) % mo * ksm((n - i) * invn % mo, k) % mo) %= mo;
		di = mo - di;
	}
	return re;
}

ll W(ll x) {
	return ksm((n - x) * invn % mo, k);
}

int main() {
//	freopen("graph.in", "r", stdin);
//	freopen("graph.out", "w", stdout);
	
	scanf("%lld %lld %lld %lld", &n, &q, &k, &w); invn = ksm(n, mo - 2); tot = n;
	
	T.Init();
	jc[0] = 1; for (ll i = 1; i <= n; i++) jc[i] = jc[i - 1] * i % mo;
	inv[0] = inv[1] = 1; for (ll i = 2; i <= n; i++) inv[i] = inv[mo % i] * (mo - mo / i) % mo;
	for (ll i = 1; i <= n; i++) inv[i] = inv[i - 1] * inv[i] % mo;
	b1 = get_B(1); b2 = get_B(2);
	
	ans = 1ll * n * (W(1) + w * b1) % mo;
	f2 = (W(2) + w * b2) % mo;
	
	while (q--) {
		scanf("%lld %lld", &x, &y); x ^= lst; y ^= lst;
		
		if (x == y) {
			printf("%lld\n", lst); continue;
		}
		
		if (!T.link(x, y)) {
			ll now = T.select(x, y), sz = T.sz[now];
			if (T.sz[now] == T.cirsz[now]) {
				ans = (ans - f2 + mo) % mo;
				T.sz[++tot] = 1;
				tmp.clear(); T.get_all(now);
				for (ll i = 0; i < sz; i++) {
					T.fa[tmp[i]] = T.ls[tmp[i]] = T.rs[tmp[i]] = 0; T.sz[tmp[i]] = T.cir[tmp[i]] = T.cirsz[tmp[i]] = 1;
					T.link(tmp[i], tot);
				}
				ans = (ans + W(sz)) % mo;
				if (w) ans = (ans + get_B(sz)) % mo;
			}
		}
		else ans = (ans - f2 + mo) % mo;
		
		
		lst = ans;
		printf("%lld\n", lst);
	}
	
	return 0;
}
posted @ 2022-02-08 14:41  あおいSakura  阅读(28)  评论(0编辑  收藏  举报