AtCoder Beginner Contest 334 G Christmas Color Grid 2

洛谷传送门

AtCoder 传送门

对每个绿色格子(#),把它涂成红色相当于断掉与它相连的所有边,然后求剩下的绿色连通块个数;对所有绿色格子的结果求和,再除以绿色格子数,即为所求期望。

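将格子按行优先编号为 \(1 \sim nm\),把“将第 \(k\) 个格子涂红”看作时刻 \(k\)。考虑一条边 \((u, v)\)(设 \(u < v\))的出现时间:它在时刻 \(k\) 存在当且仅当 \(k \ne u\) 且 \(k \ne v\),即 \([1, u - 1] \cup [u + 1, v - 1] \cup [v + 1, nm]\)。于是直接套用线段树分治和可撤销并查集即可。注意在叶子 \(k\) 处,并查集的连通块数还包含所有红色格子以及被涂红的格子 \(k\) 本身(它们都是孤立点),统计答案时需要把这些减掉。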

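设 \(N = nm\)。每条边会被挂到线段树的 \(O(\log N)\) 个结点上,所以空间复杂度为 \(O(N \log N)\)。可撤销并查集只能按秩合并、不能路径压缩,单次查找为 \(O(\log N)\),因此时间复杂度为 \(O(N \log^2 N)\)。实现得别太随意就能过了。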

code
// Problem: G - Christmas Color Grid 2
// Contest: AtCoder - UNIQUE VISION Programming Contest 2023 Christmas (AtCoder Beginner Contest 334)
// URL: https://atcoder.jp/contests/abc334/tasks/abc334_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
// Short-hand macros used throughout this contest solution.
// ID maps a 1-based (row, column) pair to a 1-based row-major cell index;
// it reads the column count `m` from the enclosing scope.
#define ID(x, y) (((x) - 1) * m + (y))
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;

// Type aliases.
using ll = long long;
using db = double;
using ull = unsigned long long;
using ldb = long double;
using pii = pair<int, int>;

// Array capacities (with a little padding) and the answer modulus.
constexpr int maxn = 1010;
constexpr int maxm = 1000100;
constexpr ll mod = 998244353;

// Modular exponentiation by repeated squaring: returns b^p modulo `mod`.
//
// b: any base; it is first reduced into [0, mod) so that `b * b` below stays
//    well inside `long long` (without the reduction a base above ~3.04e9
//    overflows, which is undefined behaviour) and negative bases work too.
// p: the exponent; must be non-negative.
// Returns a value in [0, mod).
inline ll qpow(ll b, ll p) {
	b %= mod;
	if (b < 0) {
		b += mod;
	}
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

// Grid dimensions (rows n, columns m) and the grid, stored 1-based.
ll n, m;
char s[maxn][maxn];
// Running sum, modulo `mod`, accumulated by dfs() at the leaves.
ll ans;
// vis[k] is set by solve() when cell k is red ('.').
bool vis[maxm];
// Per-cell neighbour lists, indexed by cell id.
vector<int> G[maxm];

// Rollback disjoint-set state.
int fa[maxm];   // parent link of each node
int rnk[maxm];  // union-by-rank rank of each root
int t;          // current number of components
int top;        // index of the newest entry in stk (0 means empty)
// Undo log: (address that was overwritten, value it held before).
pair<int*, int> stk[maxm * 10];

// Returns the root of x's set by following parent links.
// There is deliberately no path compression: merge() logs every link it
// changes so undo() can roll it back, and compression would rewrite links
// without logging them. Union by rank keeps these walks short.
int find(int x) {
	while (fa[x] != x) {
		x = fa[x];
	}
	return x;
}

// Union by rank of the sets containing x and y, written so it can be undone.
// Every value overwritten here (the counter t, one parent link, possibly one
// rank) is first pushed onto stk as (address, old value), so undo() can
// restore it. A successful union also decrements the component counter t.
inline void merge(int x, int y) {
	int rx = find(x);
	int ry = find(y);
	if (rx == ry) {
		return;  // already in the same set: nothing changes, nothing logged
	}
	stk[++top] = mkp(&t, t);
	--t;
	// Make rx the root whose rank is not larger; it is hung under ry.
	if (rnk[rx] > rnk[ry]) {
		std::swap(rx, ry);
	}
	stk[++top] = mkp(fa + rx, fa[rx]);
	fa[rx] = ry;
	// Equal ranks: the combined tree becomes one level taller.
	if (rnk[rx] == rnk[ry]) {
		stk[++top] = mkp(rnk + ry, rnk[ry]);
		++rnk[ry];
	}
}

// Pops the newest record off the undo log and writes the saved value back
// to the address it was taken from.
inline void undo() {
	const auto &rec = stk[top--];
	*rec.fst = rec.scd;
}

// Segment tree over time, heap-indexed (root 1, children 2*rt and 2*rt + 1).
// T[rt] lists the edges (u, v) that are alive during node rt's whole range.
// With N leaves the largest index is below 2 * 2^ceil(log2 N), so 2100000
// slots cover N <= 2^20 (about 1.05e6) -- presumably the n * m limit;
// TODO confirm if the limits change.
std::vector<std::pair<int, int>> T[2100000];

// Attaches edge x to the canonical segment-tree nodes covering [ql, qr].
// rt is the current node and [l, r] its range. An empty query range
// (ql > qr) is silently ignored.
void update(int rt, int l, int r, int ql, int qr, pii x) {
	if (qr < ql) {
		return;
	}
	const bool fullyCovered = ql <= l && r <= qr;
	if (fullyCovered) {
		T[rt].pb(x);
		return;
	}
	const int mid = l + (r - l) / 2;
	const int leftChild = 2 * rt;
	const int rightChild = 2 * rt + 1;
	if (mid >= ql) {
		update(leftChild, l, mid, ql, qr, x);
	}
	if (mid < qr) {
		update(rightChild, mid + 1, r, ql, qr, x);
	}
}

void dfs(int rt, int l, int r) {
	int lst = top;
	for (pii p : T[rt]) {
		merge(p.fst, p.scd);
	}
	if (l == r) {
		if (!vis[l]) {
			ans = (ans + t - 1) % mod;
		}
	} else {
		int mid = (l + r) >> 1;
		dfs(rt << 1, l, mid);
		dfs(rt << 1 | 1, mid + 1, r);
	}
	while (top > lst) {
		undo();
	}
}

// Registers edge (u, v), with u < v, at every time point except u and v,
// i.e. on [1, u - 1], [u + 1, v - 1] and [v + 1, cells]. update() ignores
// ranges that come out empty.
static void addEdgeTimes(int u, int v, int cells) {
	update(1, 1, cells, 1, u - 1, mkp(u, v));
	update(1, 1, cells, u + 1, v - 1, mkp(u, v));
	update(1, 1, cells, v + 1, cells, mkp(u, v));
}

// Reads the grid, registers each green-green edge on the time segment tree,
// runs the divide-and-conquer and prints the expected number of green
// components modulo `mod`.
//
// Cells are numbered row-major 1..n*m via ID; time point k stands for
// "cell k is repainted red". At a green leaf k the DSU count t also includes
// every red cell and the repainted cell k as singletons. dfs() adds t - 1
// there, so ans = sum over green k of (green components + c0).
void solve() {
	scanf("%lld%lld", &n, &m);
	const int cells = static_cast<int>(n * m);
	t = cells;  // every cell starts as its own component
	for (int i = 1; i <= cells; ++i) {
		fa[i] = i;
	}
	for (int i = 1; i <= n; ++i) {
		scanf("%s", s[i] + 1);  // columns are 1-based
	}
	int c0 = 0;  // number of red ('.') cells
	int c1 = 0;  // number of green ('#') cells
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= m; ++j) {
			const int u = static_cast<int>(ID(i, j));
			if (s[i][j] == '.') {
				vis[u] = 1;
				++c0;
				continue;
			}
			++c1;
			// Each green-green edge is registered exactly once, towards the
			// lower and right neighbour (so u < v). It goes straight into the
			// segment tree, with no per-cell adjacency vector in between.
			if (i < n && s[i + 1][j] == '#') {
				addEdgeTimes(u, static_cast<int>(ID(i + 1, j)), cells);
			}
			if (j < m && s[i][j + 1] == '#') {
				addEdgeTimes(u, static_cast<int>(ID(i, j + 1)), cells);
			}
		}
	}
	dfs(1, 1, cells);
	// Expected value = ans / c1 - c0. Since mod is prime, 1 / c1 is
	// c1^(mod - 2) by Fermat's little theorem.
	// NOTE(review): assumes c1 > 0 (at least one '#'); confirm against the
	// problem constraints.
	printf("%lld\n", (ans * qpow(c1, mod - 2) + mod - c0) % mod);
}

// Entry point: runs solve() once per test case. There is a single case by
// default; uncomment the read below for multi-test input. The counter is
// named `cases` so it does not shadow the global segment-tree array T.
int main() {
	int cases = 1;
	// scanf("%d", &cases);
	for (; cases > 0; --cases) {
		solve();
	}
	return 0;
}

posted @ 2023-12-24 11:15  zltzlt  阅读(42)  评论(0编辑  收藏  举报