AtCoder Beginner Contest 336 G 16 Integers

洛谷传送门

AtCoder 传送门

首先考虑只要求构造任意一个符合条件的 \(a\) 怎么做。考虑建图,\((i, j, k, l)\)\(\forall x \in \{0, 1\}, (j, k, l, x)\) 连有向边。那么就是要求固定每个点经过次数的一条哈密顿路径。

但是哈密顿路径仍然不好处理。考虑拆点,把原来的 \((i, j, k, l)\) 看成 \((i, j, k)\)\((j, k, l)\) 连有向边。那么要求固定每条边经过次数的欧拉路径。到这一步直接跑欧拉路径就行了。

考虑如何计数。考虑 BEST 定理,有向欧拉图的本质不同欧拉回路数量(循环同构视为本质相同,每条边互相区分)为:

\[T \prod\limits_{i = 1}^n (out_i - 1)! \]

其中 \(T\) 为图的外向生成树个数(注意到有向欧拉图以每个点为根的外向生成树个数相等),\(out_i\)\(i\) 点的出度。\(T\) 可以用矩阵树定理求得。注意去除孤立点。

但是这题统计的是欧拉路径。考虑原图若存在一对入度小于出度和入度大于出度的点,那么以它们为起点和终点,否则枚举每个点作为起点和终点即可。从终点向起点连一条有向边就转化成了欧拉回路。

注意判一些无解的情况。

时间复杂度 \(O(n^4 + \sum out_i)\),其中 \(n = 8\)

code
// Problem: G - 16 Integers
// Contest: AtCoder - AtCoder Beginner Contest 336
// URL: https://atcoder.jp/contests/abc336/tasks/abc336_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 1000100;
const int N = 1000000;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll b[19][19], in[19], out[19], fac[maxn], ifac[maxn], fa[19], id[19], di[19], m, a[19][19];

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

inline void merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x != y) {
		fa[x] = y;
	}
}

inline ll det() {
	ll ans = 1;
	for (int i = 1; i < m; ++i) {
		for (int j = i; j < m; ++j) {
			if (a[j][i]) {
				if (i != j) {
					swap(a[i], a[j]);
					ans = (mod - ans) % mod;
				}
			}
		}
		if (!a[i][i]) {
			return 0;
		}
		ans = ans * a[i][i] % mod;
		ll t = qpow(a[i][i], mod - 2);
		for (int j = i; j < m; ++j) {
			a[i][j] = a[i][j] * t % mod;
		}
		for (int j = 1; j < m; ++j) {
			if (i == j) {
				continue;
			}
			ll t = (mod - a[j][i]) % mod;
			for (int k = i + 1; k < m; ++k) {
				a[j][k] = (a[j][k] + t * a[i][k]) % mod;
			}
		}
	}
	return ans;
}

inline void init() {
	fac[0] = 1;
	for (int i = 1; i <= N; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[N] = qpow(fac[N], mod - 2);
	for (int i = N - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
}

inline ll calc(int s, int t) {
	mems(a, 0);
	for (int i = 1; i <= m; ++i) {
		int x = i;
		if (i == s || i == m) {
			x ^= (s ^ m);
		}
		a[x][x] = (i == s ? in[di[i]] + 1 : in[di[i]]);
	}
	for (int i = 1; i <= m; ++i) {
		int x = i;
		if (i == s || i == m) {
			x ^= (s ^ m);
		}
		for (int j = 1; j <= m; ++j) {
			int y = j;
			if (j == s || j == m) {
				y ^= (s ^ m);
			}
			a[x][y] = (a[x][y] - b[di[i]][di[j]] + mod) % mod;
			if (i == t && j == s) {
				a[x][y] = (a[x][y] + mod - 1) % mod;
			}
		}
	}
	ll ans = det();
	for (int i = 1; i <= m; ++i) {
		ans = ans * fac[(i == t ? out[di[i]] + 1 : out[di[i]]) - 1] % mod;
	}
	return ans;
}

void solve() {
	for (int i = 0; i < 16; ++i) {
		scanf("%lld", &b[i >> 1][i & 7]);
	}
	for (int i = 0; i < 8; ++i) {
		fa[i] = i;
	}
	for (int i = 0; i < 8; ++i) {
		for (int j = 0; j < 8; ++j) {
			out[j] += b[i][j];
			in[i] += b[i][j];
			if (b[i][j]) {
				merge(i, j);
			}
		}
	}
	int p = -1, s = -1, t = -1;
	for (int i = 0; i < 8; ++i) {
		if (!in[i] && !out[i]) {
			continue;
		}
		id[i] = ++m;
		di[m] = i;
		if (in[i] > out[i]) {
			if (in[i] > out[i] + 1 || t != -1) {
				puts("0");
				return;
			}
			t = m;
		} else if (in[i] < out[i]) {
			if (in[i] < out[i] - 1 || s != -1) {
				puts("0");
				return;
			}
			s = m;
		}
		if (p == -1) {
			p = find(i);
		} else if (p != find(i)) {
			puts("0");
			return;
		}
	}
	ll ans = 0;
	if (s != -1) {
		ans = calc(s, t);
	} else {
		for (int i = 1; i <= m; ++i) {
			ans = (ans + calc(i, i)) % mod;
		}
	}
	for (int i = 0; i < 8; ++i) {
		for (int j = 0; j < 8; ++j) {
			ans = ans * ifac[b[i][j]] % mod;
		}
	}
	printf("%lld\n", ans);
}

int main() {
	init();
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

posted @ 2024-01-15 09:09  zltzlt  阅读(52)  评论(0编辑  收藏  举报