AtCoder Beginner Contest 334 G Christmas Color Grid 2
对每个绿色格子,把它染红相当于把它的所有边全部断掉,然后求剩余绿色格子的连通块个数;答案就是这个值在所有绿色格子上的平均。
把格子按编号 \(1, 2, \dots, nm\) 看作时间轴,时刻 \(i\) 表示把第 \(i\) 个格子染红(红色格子对应的时刻不计入答案)。考虑一条边 \((u, v)\)(设 \(u < v\))的出现时间,不难发现是 \([1, u - 1] \cup [u + 1, v - 1] \cup [v + 1, nm]\)。于是直接套用线段树分治和可撤销并查集即可。
空间复杂度为 \(O(nm \log(nm))\),时间复杂度为 \(O(nm \log^2(nm))\)(可撤销并查集不能路径压缩,只按秩合并,单次查找为 \(O(\log(nm))\))。实现时注意常数,别写得太随意就能过。
code
// Problem: G - Christmas Color Grid 2
// Contest: AtCoder - UNIQUE VISION Programming Contest 2023 Christmas (AtCoder Beginner Contest 334)
// URL: https://atcoder.jp/contests/abc334/tasks/abc334_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
// Maps 1-based grid coordinates (row x, column y) to a 1-based cell index in
// [1, n*m], row-major. Expands to an expression using the global m (columns).
#define ID(x, y) (((x) - 1) * m + (y))
// Competitive-programming shorthands.
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;
// Array bounds: maxn sizes the grid rows/columns (s[maxn][maxn]); maxm sizes
// the per-cell arrays indexed by ID(x, y).
const int maxn = 1010;
const int maxm = 1000100;
// Answer modulus (a prime), so the division by the number of green cells is
// done as multiplication by qpow(c1, mod - 2).
const ll mod = 998244353;
// Binary exponentiation: returns b^p modulo `mod`.
// solve() calls it with p = mod - 2 to get a modular inverse (Fermat's little
// theorem, since mod is prime).
inline ll qpow(ll b, ll p) {
	ll acc = 1;
	for (; p != 0; p >>= 1, b = b * b % mod) {
		if (p & 1) {
			acc = acc * b % mod;
		}
	}
	return acc;
}
// Grid size (n rows, m columns) and the running sum, modulo mod, of
// (components - 1) over every green cell removed at a DFS leaf.
ll n, m, ans;
// Grid rows, stored 1-based: s[i][1..m].
char s[maxn][maxn];
// vis[id] is set for red ('.') cells; their time steps are skipped at leaves.
bool vis[maxm];
// G[u]: green neighbours v of green cell u with v > u (the cell below and the
// cell to the right), so every undirected green-green edge is stored once.
vector<int> G[maxm];
// Rollback DSU: parent, union-by-rank rank, current component count t (over
// all n*m cells, red ones included) and the top index of the undo journal.
int fa[maxm], rnk[maxm], t, top;
// Undo journal of (address, previous value) pairs, replayed in LIFO order.
// Each successful merge pushes at most 3 entries and at most n*m - 1 merges can
// be live at once, so the required depth stays well under maxm * 10.
pair<int*, int> stk[maxm * 10];
// Returns the root of x's set. No path compression, on purpose: compressed
// parent links could not be rolled back by undo(). Union by rank in merge()
// keeps the walk at O(log) steps.
int find(int x) {
	while (fa[x] != x) {
		x = fa[x];
	}
	return x;
}
// Unites the sets containing x and y (union by rank). Before overwriting a
// field, it journals (address, old value) onto stk so undo() can revert it.
// Does nothing (and journals nothing) if x and y are already connected.
inline void merge(int x, int y) {
	const int rx = find(x);
	const int ry = find(y);
	if (rx == ry) {
		return;
	}
	// One fewer component.
	stk[++top] = mkp(&t, t);
	--t;
	if (rnk[rx] > rnk[ry]) {
		// Strictly higher-ranked root rx absorbs ry; no rank changes.
		stk[++top] = mkp(fa + ry, fa[ry]);
		fa[ry] = rx;
		return;
	}
	// rnk[rx] <= rnk[ry]: hang rx below ry, bumping ry's rank on a tie.
	stk[++top] = mkp(fa + rx, fa[rx]);
	fa[rx] = ry;
	if (rnk[rx] == rnk[ry]) {
		stk[++top] = mkp(rnk + ry, rnk[ry]);
		++rnk[ry];
	}
}
inline void undo() {
*stk[top].fst = stk[top].scd;
--top;
}
// Segment tree over time steps 1..n*m (time i = cell i is repainted red).
// T[rt] holds the edges alive during the whole range covered by node rt.
// Heap-style indexing (children 2rt, 2rt+1): assuming n*m <= 2^20 leaves, node
// indices stay below 2^21 = 2097152 < 2100000.
vector<pii> T[2100000];
// Attaches edge x to the canonical segment-tree nodes covering [ql, qr],
// starting from node rt, which spans [l, r]. An empty interval (ql > qr) is a
// no-op.
void update(int rt, int l, int r, int ql, int qr, pii x) {
	if (qr < ql) {
		return;
	}
	const bool fully_covered = ql <= l && qr >= r;
	if (fully_covered) {
		T[rt].pb(x);
		return;
	}
	const int mid = (l + r) / 2;
	if (ql <= mid) {
		update(rt * 2, l, mid, ql, qr, x);
	}
	if (mid < qr) {
		update(rt * 2 + 1, mid + 1, r, ql, qr, x);
	}
}
void dfs(int rt, int l, int r) {
int lst = top;
for (pii p : T[rt]) {
merge(p.fst, p.scd);
}
if (l == r) {
if (!vis[l]) {
ans = (ans + t - 1) % mod;
}
} else {
int mid = (l + r) >> 1;
dfs(rt << 1, l, mid);
dfs(rt << 1 | 1, mid + 1, r);
}
while (top > lst) {
undo();
}
}
void solve() {
scanf("%lld%lld", &n, &m);
t = n * m;
for (int i = 1; i <= n * m; ++i) {
fa[i] = i;
}
for (int i = 1; i <= n; ++i) {
scanf("%s", s[i] + 1);
}
int c0 = 0, c1 = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
if (s[i][j] == '.') {
vis[ID(i, j)] = 1;
++c0;
continue;
}
++c1;
if (i < n && s[i + 1][j] == '#') {
G[ID(i, j)].pb(ID(i + 1, j));
}
if (j < m && s[i][j + 1] == '#') {
G[ID(i, j)].pb(ID(i, j + 1));
}
}
}
for (int i = 1; i <= n * m; ++i) {
for (int j : G[i]) {
update(1, 1, n * m, 1, i - 1, mkp(i, j));
update(1, 1, n * m, i + 1, j - 1, mkp(i, j));
update(1, 1, n * m, j + 1, n * m, mkp(i, j));
}
}
dfs(1, 1, n * m);
printf("%lld\n", (ans * qpow(c1, mod - 2) + mod - c0) % mod);
}
// Entry point. The problem has a single test case. The counter is named
// `cases` so it does not shadow the global segment-tree array T.
int main() {
	int cases = 1;
	// scanf("%d", &cases);
	while (cases--) {
		solve();
	}
	return 0;
}